# First Run For Best Model

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.interpolate import interp1d
from scipy.sparse import csc_matrix, eye, diags
from scipy.sparse.linalg import spsolve
from scipy.optimize import curve_fit
import tensorflow as tf
import keras
from keras import layers, models, regularizers
import pyts
from pyts.image import GramianAngularField as GADF
import json
import random
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random-number source used by this notebook.

    Exports ``PYTHONHASHSEED`` and seeds NumPy's global generator,
    TensorFlow and Python's ``random`` module with the same value.

    Args:
        seed: Integer seed shared by all generators (default 42).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (np.random.seed, tf.random.set_seed, random.seed):
        seeder(seed)

# Fix the seed once so every generator starts from a known state.
SEED = 42
set_seed(SEED)

# Per-run experiment folder, named after the launch time.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
experiment_dir = os.path.join('/kaggle/working/experiments', f'experiment_{timestamp}')
model_dir = os.path.join(experiment_dir, 'models')
results_dir = os.path.join(experiment_dir, 'results')

# INPUT directory (read-only)
data_dir = '/kaggle/input/ethanol-methanol/data'

# OUTPUT directories (writable)
output_dir = '/kaggle/working'
synthetic_dir = os.path.join(output_dir, 'synthetic')
maps_dir = os.path.join(output_dir, 'maps')
labels_dir = os.path.join(output_dir, 'labels')
visualizations_dir = os.path.join(output_dir, 'visualizations')
problematic_spectra_dir = os.path.join(output_dir, 'problematic_spectra')

# Create every writable folder up front; existing folders are left alone.
for _writable_dir in (experiment_dir, model_dir, results_dir,
                      synthetic_dir, maps_dir, labels_dir,
                      visualizations_dir, problematic_spectra_dir):
    os.makedirs(_writable_dir, exist_ok=True)

# Record the run's hyper-parameters next to its outputs.
config = dict(
    seed=SEED,
    spectrum_length=880,
    image_size=64,
    num_synthetic_per_type=999,
    num_types=11,
    total_spectra=11000,
    epochs=10,
    batch_size_1d=64,
    batch_size_2d=32,
    validation_split=0.1,
    test_size=0.2,
    growth_rate=12,
    num_classes=11,
    experiment_timestamp=timestamp,
)
with open(os.path.join(experiment_dir, 'config.json'), 'w') as f:
    f.write(json.dumps(config, indent=4))


In [None]:
# ============================================================================
# CONSTANTS: Chemically Meaningful Raman Bands
# ============================================================================
# Each entry maps a band key to its (low, high) wavenumber window and a
# human-readable name used in plots and reports.
RAMAN_BANDS = {
    band_key: {'range': window, 'name': label}
    for band_key, window, label in (
        ('ethanol_cc', (870, 890), 'Ethanol C-C stretch'),
        ('methanol_co', (1020, 1055), 'Methanol C-O stretch'),
        ('ethanol_co', (1055, 1100), 'Ethanol C-O stretch'),
        ('ch_bend', (1450, 1480), 'C-H bending'),
        ('ch_stretch', (2800, 3000), 'C-H stretch'),
        ('oh_stretch', (3300, 3400), 'O-H stretch'),
    )
}

# ============================================================================
# BASELINE FUNCTIONS
# ============================================================================
def poly_baseline(x, p, intensity, b):
    """Power-law baseline rescaled so that its maximum equals ``intensity``.

    Args:
        x: 1-D array-like of sample positions; it is divided by ``len(x)``
            before being raised to the power ``p``.
        p: Exponent of the power law.
        intensity: Value the largest point of the curve is scaled to.
        b: Constant offset added before rescaling.

    Returns:
        NumPy array with the same length as ``x``.
    """
    x = np.asarray(x, dtype=float)
    y = (x / len(x)) ** p + b
    # np.max is vectorised; the builtin max() would iterate the array
    # element by element on every curve_fit evaluation.
    return y * intensity / np.max(y)

def gaussian_baseline(x, mean, sd, intensity, b):
    """Gaussian-shaped baseline rescaled so that its maximum equals ``intensity``.

    Args:
        x: 1-D array-like of sample positions.
        mean: Centre of the Gaussian, in the same units as ``x``.
        sd: Standard deviation of the Gaussian.
        intensity: Value the largest point of the curve is scaled to.
        b: Constant offset added to the Gaussian density before rescaling.

    Returns:
        NumPy array with the same length as ``x``.
    """
    x = np.asarray(x, dtype=float)
    y = np.exp(-(x - mean) ** 2 / (2 * sd ** 2)) / (sd * np.sqrt(2 * np.pi)) + b
    # np.max is vectorised; the builtin max() would iterate the array
    # element by element on every curve_fit evaluation.
    return y * intensity / np.max(y)

def pg_baseline(x, p, in1, mean, sd, in2, b):
    """Sum of a power-law and a Gaussian baseline, each scaled independently.

    Each component is offset by ``b``, normalised by its own maximum and
    multiplied by its intensity before the two are added.

    Args:
        x: 1-D array-like of sample positions.
        p: Exponent of the power-law component.
        in1: Peak value of the power-law component.
        mean: Centre of the Gaussian component.
        sd: Standard deviation of the Gaussian component.
        in2: Peak value of the Gaussian component.
        b: Constant offset added to both components before normalisation.

    Returns:
        NumPy array with the same length as ``x``.
    """
    x = np.asarray(x, dtype=float)
    y1 = (x / len(x)) ** p + b
    y2 = np.exp(-(x - mean) ** 2 / (2 * sd ** 2)) / (sd * np.sqrt(2 * np.pi)) + b
    # np.max is vectorised; the builtin max() would iterate the arrays
    # element by element on every curve_fit evaluation.
    return y1 / np.max(y1) * in1 + y2 / np.max(y2) * in2

def mix_min_no(sp, baseline):
    """Return the element-wise minimum of a spectrum and a baseline estimate."""
    lower_envelope = np.minimum(baseline, sp)
    return lower_envelope

def iterative_fitting_with_bounds_no(sp, model, ite=10):
    """Iteratively estimate a baseline, guided by a two-output advisor model.

    On every pass the model is asked which baseline family is present:
    output ``[0][0] >= 0.5`` selects the power-law fit, ``[0][1] >= 0.5``
    the Gaussian fit, and both together the combined fit. The selected
    function is fitted with bounded ``curve_fit`` and the running estimate
    is replaced by its element-wise minimum with the fitted curve.

    Args:
        sp: 1-D spectrum (``sp.shape[0]`` points).
        model: Callable returning a ``(1, 2)`` score array/tensor.
        ite: Number of refinement passes.

    Returns:
        The running lower-envelope estimate after ``ite`` passes.
    """
    n_points = sp.shape[0]
    x = np.linspace(1, n_points, n_points)

    # (fit function, lower bounds, upper bounds) for each advised family.
    poly_fit = (poly_baseline, [1, 0.5, -0.5], [3, 1, 0.5])
    gauss_fit = (gaussian_baseline, [0, 100, 0.5, -0.5], [n_points, 600, 1, 0.5])
    mixed_fit = (pg_baseline,
                 [1, 0.5, 0, 100, 0.5, -0.5],
                 [3, 1, n_points, 600, 1, 0.5])

    fitted_baseline = np.zeros(n_points)
    tempb = sp
    model_input = tf.expand_dims(tempb, axis=0)
    for _ in range(ite):
        tadvice = model(model_input)
        has_poly = bool(tadvice[0][0] >= 0.5)
        has_gauss = bool(tadvice[0][1] >= 0.5)

        if has_poly and has_gauss:
            choice = mixed_fit
        elif has_poly:
            choice = poly_fit
        elif has_gauss:
            choice = gauss_fit
        else:
            # No family advised: keep the previously fitted curve.
            choice = None

        if choice is not None:
            fit_fn, lower, upper = choice
            try:
                params, _ = curve_fit(fit_fn, x, tempb,
                                      bounds=(lower, upper),
                                      maxfev=10000)
                fitted_baseline = fit_fn(x, *params)
            except RuntimeError:
                # Fit did not converge: leave the estimate unchanged this pass.
                fitted_baseline = tempb

        tempb = mix_min_no(tempb, fitted_baseline)
        model_input = np.array(tempb).reshape(1, n_points)
    return tempb

def create_baseline_model(input_shape=880):
    """Build the small 1-D CNN that advises which baseline family to fit.

    Args:
        input_shape: Number of points in the input spectrum.

    Returns:
        An uncompiled ``Sequential`` model with two sigmoid outputs.
    """
    feature_extractor = [
        layers.Input(shape=(input_shape,)),
        layers.Reshape((input_shape, 1)),
        layers.Conv1D(filters=16, kernel_size=5, strides=1, activation='relu'),
        layers.AveragePooling1D(pool_size=2, strides=2),
        layers.Flatten(),
    ]
    classifier_head = [
        layers.Dense(100, activation='relu'),
        layers.Dense(2, activation='sigmoid'),
    ]
    return models.Sequential(feature_extractor + classifier_head)

def train_baseline_model(baseline_model, noise_data, epochs=10, batch_size=32):
    """Train the baseline-advisor model on the noisy spectra of ``noise_data``.

    Labels are read from ``labels_noise_pure_182.npy`` in ``data_dir``; if
    that fails, random 0/1 labels are substituted and a message is printed.
    After training, the weights are written to ``model.weights.h5``.

    Args:
        baseline_model: Keras model to compile and fit.
        noise_data: Array indexed as ``[sample, 0, 1, :, 0]`` for the noisy
            spectrum of each sample.
        epochs: Number of training epochs.
        batch_size: Mini-batch size.

    Returns:
        The trained ``baseline_model``.
    """
    try:
        labels = np.load(os.path.join(data_dir, 'labels_noise_pure_182.npy'))
        print("Labels loaded from labels_noise_pure_182.npy successfully!")
    except Exception as e:
        print(f"Error loading labels: {e}. Using random labels.")
        labels = np.random.randint(0, 2, size=noise_data.shape[0])

    n_samples = noise_data.shape[0]
    # Noisy spectra only; a trailing channel axis is appended as before.
    X = np.array(noise_data[:, 0, 1, :, 0])[:, :, np.newaxis]
    y = np.array([labels[i] for i in range(n_samples)])

    # NOTE(review): the model ends in Dense(2, sigmoid) and inference
    # thresholds both outputs independently; confirm the label format matches
    # 'sparse_categorical_crossentropy' rather than a multi-label loss.
    baseline_model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    baseline_model.fit(X, y, epochs=epochs, batch_size=batch_size, validation_split=0.1)

    weights_name = 'model.weights.h5'
    try:
        baseline_model.save_weights(os.path.join(data_dir, weights_name))
    except OSError as e:
        # data_dir is the read-only input directory on Kaggle; fall back to
        # the writable output directory instead of losing the trained model.
        fallback_path = os.path.join(output_dir, weights_name)
        print(f"Could not save weights to {data_dir}: {e}. Saving to {fallback_path} instead.")
        baseline_model.save_weights(fallback_path)
    return baseline_model

# ============================================================================
# SPECTRUM PROCESSING FUNCTIONS
# ============================================================================
def normalize_spectrum(spectrum):
    """Min-max scale a spectrum to the range [0, 1].

    The minimum is subtracted first; the result is divided by its maximum
    only when that maximum is positive, so a flat spectrum becomes all zeros.
    """
    shifted = spectrum - np.min(spectrum)
    peak = np.max(shifted)
    return shifted / peak if peak > 0 else shifted

def WhittakerSmooth(x, w, lambda_=1, differences=1):
    """Weighted Whittaker smoother (penalised least squares).

    Solves ``(W + lambda_ * D.T @ D) z = W @ x`` where ``W = diag(w)`` and
    ``D`` is the finite-difference operator of order ``differences``.

    Args:
        x: 1-D array-like signal of length ``m``.
        w: Length-``m`` weights; larger values pull ``z`` closer to ``x``.
        lambda_: Smoothness penalty; larger values give a smoother result.
        differences: Order of the difference penalty.

    Returns:
        1-D NumPy array of length ``m`` holding the smoothed signal.
    """
    # Plain ndarrays and '@' replace np.matrix, which NumPy no longer
    # recommends, and avoid relying on '*' meaning matrix multiplication.
    signal = np.asarray(x, dtype=float).ravel()
    m = signal.size
    D = eye(m, format='csc')
    for _ in range(differences):
        D = D[1:] - D[:-1]
    W = diags(w, 0, shape=(m, m))
    A = csc_matrix(W + lambda_ * (D.T @ D))
    rhs = W @ signal
    background = spsolve(A, rhs)
    return np.asarray(background)

def airPLS(x, lambda_=100, porder=1, itermax=15):
    """
    Adaptive iteratively reweighted Penalized Least Squares (airPLS) baseline estimation.

    Repeatedly fits a Whittaker-smoothed curve ``z`` to ``x`` and re-weights
    the points from the residual ``d = x - z`` until the summed negative
    residual is small relative to the signal. The function returns the
    estimated baseline, not the baseline-corrected spectrum.

    Args:
        x: 1-D NumPy array holding the spectrum.
        lambda_: NOTE(review): the passed value is overwritten below by a
            data-dependent value clamped to [50, 500], so this argument
            currently has no effect.
        porder: Forwarded to ``WhittakerSmooth`` as its ``differences`` argument.
        itermax: Maximum number of re-weighting iterations.

    Returns:
        The baseline estimate ``z``. When ``itermax`` is reached, an
        unweighted Whittaker smooth of ``x`` with ``lambda_=50`` is returned
        instead.

    Raises:
        ValueError: If ``x`` is empty.

    Side effects:
        When ``itermax`` is reached, prints a warning and saves ``x`` as a
        ``.npy`` file in ``problematic_spectra_dir``; the file name is drawn
        from NumPy's global random generator.

    Edge cases handled:
    - Empty input spectra
    - Spectra entirely above baseline (no negative residuals)
    - Near-zero dssn values (division protection)
    """
    m = x.shape[0]
    if m == 0:
        raise ValueError("Input spectrum is empty")

    w = np.ones(m)
    # Data-driven smoothness: 50 * std / mean(|x|), clamped to [50, 500].
    # This replaces whatever the caller passed as lambda_.
    lambda_ = max(50, min(500, 50 * np.std(x) / (np.mean(np.abs(x)) + 1e-6)))

    for i in range(1, itermax + 1):
        z = WhittakerSmooth(x, w, lambda_, porder)
        d = x - z

        # Points lying below the current fit.
        neg_mask = d < 0
        neg_count = np.sum(neg_mask)

        # Handle case with no negative residuals (spectrum entirely above baseline)
        if neg_count == 0:
            return z

        # Magnitude of the summed negative residuals.
        dssn = np.abs(d[neg_mask].sum())

        # Avoid division by near-zero
        if dssn < 1e-10:
            return z

        # Converged: negative residual mass is below 0.1% of the signal mass.
        if dssn < 0.001 * np.abs(x).sum():
            return z

        if i == itermax:
            print(f'WARNING: Max iteration reached! lambda_={lambda_:.2f}, dssn={dssn:.2e}')
            # Keep the offending spectrum for later inspection.
            np.save(os.path.join(problematic_spectra_dir, f'problematic_spectrum_{np.random.randint(1000000)}.npy'), x)
            return WhittakerSmooth(x, np.ones(m), lambda_=50)

        # Ignore points at or above the fit; weight points below it
        # exponentially, growing with the iteration number.
        w[d >= 0] = 0
        w[neg_mask] = np.exp(i * np.abs(d[neg_mask]) / dssn)
        # Both end points receive the largest weight of this iteration.
        w[0] = np.exp(i * np.abs(d[neg_mask]).max() / dssn)
        w[-1] = w[0]

    # Only reached when the loop body never runs (itermax < 1).
    return z

def enhanced_baseline_correction(spectrum, baseline_model):
    """Subtract a baseline estimated by model-guided fitting refined with airPLS.

    If any step of the hybrid path raises, a message is printed and the
    baseline is estimated with airPLS directly on the input spectrum.
    In both cases negative values of the corrected spectrum are clipped to 0.

    Args:
        spectrum: 1-D NumPy spectrum.
        baseline_model: Advisor model passed to the iterative fitting step.

    Returns:
        The non-negative, baseline-corrected spectrum.
    """
    try:
        coarse_baseline = iterative_fitting_with_bounds_no(spectrum, baseline_model)
        refined_baseline = airPLS(coarse_baseline, lambda_=100, itermax=15)
        corrected = np.clip(spectrum - refined_baseline, 0, None)
    except Exception as e:
        print(f"Error in baseline correction: {e}. Falling back to airPLS.")
        fallback_baseline = airPLS(spectrum, lambda_=100, itermax=15)
        corrected = np.clip(spectrum - fallback_baseline, 0, None)
    return corrected

def interpolate_spectrum(spectrum, original_length, target_length=880):
    """Linearly resample a spectrum to ``target_length`` points.

    Source and target grids both span index positions ``0`` to
    ``original_length - 1``, so the end points are preserved.

    Args:
        spectrum: 1-D sequence with ``original_length`` samples.
        original_length: Number of samples in ``spectrum``.
        target_length: Number of samples in the result.

    Returns:
        NumPy array of length ``target_length``.
    """
    if original_length == 1:
        # interp1d needs at least two points; a single sample can only be
        # extended as a constant.
        return np.full(target_length, spectrum[0], dtype=float)
    x_original = np.linspace(0, original_length - 1, original_length)
    x_target = np.linspace(0, original_length - 1, target_length)
    interpolator = interp1d(x_original, spectrum, kind='linear', fill_value="extrapolate")
    return interpolator(x_target)

def shift_spectrum(spectrum, shift):
    """Circularly shift a spectrum by ``shift`` samples.

    Samples pushed past one end re-enter at the other (``np.roll``).
    """
    rolled = np.roll(spectrum, shift)
    return rolled

def stretch_spectrum(spectrum, alpha):
    """Resample a spectrum to ``len / alpha`` points and back to its length.

    The intermediate length is clamped to between 1 and ten times the
    original length. Both resampling steps span the full index range of
    the input, and the result has the same length as ``spectrum``.

    Args:
        spectrum: 1-D sequence of samples.
        alpha: Stretch factor; the intermediate grid has
            ``int(len(spectrum) / alpha)`` points.

    Returns:
        NumPy array with ``len(spectrum)`` samples.
    """
    n_points = len(spectrum)
    resampled_len = min(max(int(n_points / alpha), 1), n_points * 10)

    source_grid = np.linspace(0, n_points - 1, n_points)
    resample_grid = np.linspace(0, n_points - 1, resampled_len)
    stretched = interp1d(source_grid, spectrum, kind='linear',
                         fill_value="extrapolate")(resample_grid)
    return interpolate_spectrum(stretched, resampled_len, n_points)

def generate_synthetic_spectrum(input_spectrum, noise_data, spectrum_length=880):
    """Generate synthetic spectrum with noise, baseline, and augmentation.

    All randomness comes from NumPy's global generator, so the result is
    reproducible only for a fixed seed and a fixed sequence of calls.

    Args:
        input_spectrum: 1-D spectrum to perturb; it is not modified in place.
            Its length must match ``spectrum_length`` for the Gaussian noise
            to be added.
        noise_data: Array indexed as ``[sample, 0, 0, :, 0]`` (pure) and
            ``[sample, 0, 1, :, 0]`` (noisy); their difference is used as a
            noise trace. Assumes these traces have the same length as
            ``input_spectrum`` - TODO confirm against the dataset.
        spectrum_length: Number of points in the spectrum.

    Returns:
        The perturbed spectrum as a new NumPy array.
    """
    # Sample positions handed to the synthetic baseline functions.
    x_range = np.linspace(0, spectrum_length, spectrum_length)
    # Add noise from dataset: pick one sample and scale its (noisy - pure)
    # residual by a random factor in [1, 2).
    noise_idx = np.random.randint(0, noise_data.shape[0])
    pure = noise_data[noise_idx, 0, 0, :, 0]
    noisy = noise_data[noise_idx, 0, 1, :, 0]
    noise = noisy - pure
    scale = np.random.uniform(1.0, 2.0)
    synthetic_spectrum = input_spectrum + noise * scale
    # Add Gaussian noise with a standard deviation of 5% of the input's std.
    synthetic_spectrum += np.random.normal(0, 0.05 * np.std(input_spectrum), spectrum_length)
    # Add synthetic baseline: 30% power-law, 30% Gaussian, 40% none.
    baseline_type = np.random.choice(['poly', 'gaussian', 'none'], p=[0.3, 0.3, 0.4])
    if baseline_type == 'poly':
        baseline = poly_baseline(x_range, p=np.random.uniform(1.9, 2.1),
                                intensity=np.random.uniform(0.75, 0.8),
                                b=np.random.uniform(-0.1, 0.1))
        synthetic_spectrum += baseline
    elif baseline_type == 'gaussian':
        baseline = gaussian_baseline(x_range, mean=np.random.uniform(0, spectrum_length),
                                   sd=np.random.uniform(250, 300),
                                   intensity=np.random.uniform(0.75, 0.8),
                                   b=np.random.uniform(-0.1, 0.1))
        synthetic_spectrum += baseline
    # Probabilistic augmentation: 50% none, 25% shift, 25% stretch.
    aug_type = np.random.choice(['none', 'shift', 'stretch'], p=[0.5, 0.25, 0.25])
    if aug_type == 'shift':
        # Integer shift drawn from -10..10 inclusive.
        shift = np.random.randint(-10, 11)
        synthetic_spectrum = shift_spectrum(synthetic_spectrum, shift)
    elif aug_type == 'stretch':
        alpha = np.random.uniform(0.5, 2.0)
        synthetic_spectrum = stretch_spectrum(synthetic_spectrum, alpha)
    return synthetic_spectrum

def create_gadf_map(spectrum, image_size=64):
    """Convert a spectrum into a single-channel Gramian angular difference image.

    The spectrum is min-max normalised, rescaled to [-1, 1] and, if its
    length is not a multiple of ``image_size``, resampled down to the
    largest multiple that fits before the transform is applied.

    Args:
        spectrum: 1-D NumPy spectrum.
        image_size: Side length of the square output image.

    Returns:
        Array of shape ``(image_size, image_size, 1)``.
    """
    scaled = 2 * normalize_spectrum(spectrum) - 1
    n_points = scaled.shape[0]
    usable_length = (n_points // image_size) * image_size
    if usable_length != n_points:
        scaled = interpolate_spectrum(scaled, n_points, usable_length)
    transformer = GADF(image_size=image_size, method='difference')
    image = transformer.fit_transform(scaled.reshape(1, -1))[0]
    return image[:, :, np.newaxis]

# ============================================================================
# PHYSICALLY MEANINGFUL METRICS (Task 3)
# ============================================================================
def wavenumber_to_index(wavenumber, wavenumbers):
    """Return the index of the axis entry closest to ``wavenumber``.

    Ties resolve to the first (lowest-index) closest entry.
    """
    distances = np.abs(wavenumbers - wavenumber)
    return np.argmin(distances)

def calculate_snr(spectrum, signal_region, noise_region):
    """
    Signal-to-noise ratio of a spectrum, in decibels.

    Signal power is the mean squared intensity inside ``signal_region``;
    noise power is the variance inside ``noise_region``.

    Args:
        spectrum: Input spectrum (1-D NumPy array)
        signal_region: Tuple (start_idx, end_idx) for signal peak
        noise_region: Tuple (start_idx, end_idx) for baseline/noise

    Returns:
        SNR in dB, or ``np.inf`` when the noise variance is below 1e-10.
    """
    signal_window = spectrum[signal_region[0]:signal_region[1]]
    noise_window = spectrum[noise_region[0]:noise_region[1]]

    noise_power = np.var(noise_window)
    if noise_power < 1e-10:
        return np.inf

    signal_power = np.mean(signal_window ** 2)
    return 10 * np.log10(signal_power / noise_power)

def calculate_peak_ratios(spectrum, wavenumbers):
    """
    Intensity ratios between three reference peaks.

    Peak heights are taken as the maximum of the spectrum in a window
    around the index closest to 1050 (ethanol C-O), 1030 (methanol C-O)
    and 2900 (C-H stretch) on ``wavenumbers``.

    Returns:
        Dictionary of peak ratios
    """
    def window_max(target_wavenumber, half_width):
        # Window runs from centre - half_width (floored at 0) up to, but
        # not including, centre + half_width.
        centre = wavenumber_to_index(target_wavenumber, wavenumbers)
        return np.max(spectrum[max(0, centre - half_width):centre + half_width])

    ethanol_co_peak = window_max(1050, 5)
    methanol_co_peak = window_max(1030, 5)
    ch_stretch_peak = window_max(2900, 10)

    # A tiny epsilon keeps the ratios finite when a reference peak is zero.
    ch_reference = ch_stretch_peak + 1e-10
    return {
        'ethanol_co_to_ch': ethanol_co_peak / ch_reference,
        'methanol_co_to_ch': methanol_co_peak / ch_reference,
        'ethanol_to_methanol_co': ethanol_co_peak / (methanol_co_peak + 1e-10)
    }

def evaluate_preprocessing_quality(raw_spectrum, processed_spectrum, wavenumbers):
    """Compare a raw and a processed spectrum using SNR and peak ratios.

    SNR is computed with fixed index windows: 673-720 as the signal
    (C-H stretch) and 50-100 as the noise (flat baseline region).

    Returns:
        Dict with the raw and processed SNR, their difference, and the
        peak-ratio dictionaries of both spectra.
    """
    signal_region = (673, 720)  # C-H stretch
    noise_region = (50, 100)    # Flat baseline region

    report = {}
    report['snr_raw'] = calculate_snr(raw_spectrum, signal_region, noise_region)
    report['snr_processed'] = calculate_snr(processed_spectrum, signal_region, noise_region)
    report['snr_improvement'] = report['snr_processed'] - report['snr_raw']

    # Peak ratios before and after, to judge how well they are preserved.
    report['peak_ratios_raw'] = calculate_peak_ratios(raw_spectrum, wavenumbers)
    report['peak_ratios_processed'] = calculate_peak_ratios(processed_spectrum, wavenumbers)
    return report

# ============================================================================
# CHEMICALLY MEANINGFUL OCCLUSION ANALYSIS (Task 2)
# ============================================================================
from sklearn.metrics import accuracy_score

def chemically_meaningful_occlusion(model, X, y, wavenumbers, bands=None):
    """
    Measure how much accuracy drops when each Raman band is zeroed out.

    Args:
        model: Trained classifier exposing ``predict``
        X: Input spectra (N, 880, 1)
        y: True labels
        wavenumbers: Wavenumber array
        bands: Dictionary of Raman bands; defaults to ``RAMAN_BANDS``

    Returns:
        Dictionary with per-band occlusion results under ``'bands'`` and
        the unoccluded accuracy (in percent) under ``'baseline_accuracy'``
    """
    band_table = RAMAN_BANDS if bands is None else bands

    def accuracy_on(inputs):
        # Fraction of samples whose arg-max prediction matches y.
        predicted = model.predict(inputs, verbose=0).argmax(axis=1)
        return accuracy_score(y, predicted)

    baseline_acc = accuracy_on(X)

    per_band = {}
    for band_key, band_info in band_table.items():
        low_wn, high_wn = band_info['range']
        start_idx = wavenumber_to_index(low_wn, wavenumbers)
        end_idx = wavenumber_to_index(high_wn, wavenumbers)

        # Work on a copy so the caller's data is never modified.
        masked = X.copy()
        masked[:, start_idx:end_idx, :] = 0
        occluded_acc = accuracy_on(masked)

        per_band[band_key] = {
            'name': band_info['name'],
            'wavenumber_range': band_info['range'],
            'index_range': (start_idx, end_idx),
            'accuracy_drop': (baseline_acc - occluded_acc) * 100,
            'occluded_accuracy': occluded_acc * 100
        }

    return {'bands': per_band, 'baseline_accuracy': baseline_acc * 100}

def plot_occlusion_analysis(occlusion_results, output_path):
    """Save a bar chart of the accuracy drop caused by occluding each band.

    Bars are red for drops above 5, orange above 2 and green otherwise.

    Args:
        occlusion_results: Dict with ``'bands'`` and ``'baseline_accuracy'``.
        output_path: Destination image path.
    """
    def severity_colour(drop):
        if drop > 5:
            return 'red'
        if drop > 2:
            return 'orange'
        return 'green'

    band_results = list(occlusion_results['bands'].values())
    baseline = occlusion_results['baseline_accuracy']
    band_names = [entry['name'] for entry in band_results]
    acc_drops = [entry['accuracy_drop'] for entry in band_results]
    colors = [severity_colour(drop) for drop in acc_drops]

    fig, ax = plt.subplots(figsize=(12, 6))
    bars = ax.bar(band_names, acc_drops, color=colors, edgecolor='black')

    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    ax.set_xlabel('Raman Band', fontsize=12)
    ax.set_ylabel('Accuracy Drop (%)', fontsize=12)
    ax.set_title(f'Chemically Meaningful Occlusion Analysis\n(Baseline Accuracy: {baseline:.1f}%)', fontsize=14)
    plt.xticks(rotation=45, ha='right')

    # Print each bar's value just above it.
    for bar, drop in zip(bars, acc_drops):
        x_centre = bar.get_x() + bar.get_width()/2
        ax.text(x_centre, bar.get_height() + 0.2,
                f'{drop:.1f}%', ha='center', va='bottom', fontsize=10)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

# ============================================================================
# VISUALIZATION FUNCTIONS (Tasks 4-8)
# ============================================================================
def plot_representative_spectra_by_concentration(spectra_data, wavenumbers, ratios_dict,
                                                  output_path, figsize=(14, 10)):
    """Plot representative spectra showing spectral evolution with concentration.

    The top panel overlays all normalised spectra and shades the methanol
    and ethanol C-O regions; the bottom panel is a waterfall plot with one
    vertically offset trace per spectrum.

    Args:
        spectra_data: Mapping of sample name to spectrum.
        wavenumbers: X-axis values shared by all spectra.
        ratios_dict: Mapping of sample name to ethanol fraction (0-1);
            names missing from it are treated as 0.5.
        output_path: Destination image path.
        figsize: Figure size passed to ``plt.subplots``.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=figsize)

    # Colormap for concentration gradient (11 steps: 0%, 10%, ..., 100%)
    colors = plt.cm.RdYlGn(np.linspace(0, 1, 11))

    def colour_and_percent(name):
        ratio = ratios_dict.get(name, 0.5)
        # round() rather than int(): products such as 0.29 * 100 evaluate to
        # 28.999..., which truncation would report as 28.
        return colors[int(round(ratio * 10))], int(round(ratio * 100))

    # Plot all concentrations overlaid
    for name, spectrum in spectra_data.items():
        colour, percent = colour_and_percent(name)
        ax1.plot(wavenumbers, normalize_spectrum(spectrum),
                 color=colour, label=f"{percent}% Ethanol", alpha=0.8)

    # Highlight characteristic regions
    ax1.axvspan(1020, 1055, color='red', alpha=0.1, label='Methanol C-O')
    ax1.axvspan(1055, 1100, color='green', alpha=0.1, label='Ethanol C-O')

    ax1.set_xlabel('Wavenumber (cm$^{-1}$)')
    ax1.set_ylabel('Normalized Intensity')
    ax1.set_title('Spectral Evolution: Methanol to Ethanol')
    # The legend is built after the shaded regions are drawn so that their
    # labels are included.
    ax1.legend(loc='upper right', fontsize=8, ncol=2)
    ax1.grid(True, alpha=0.3)

    # Waterfall plot
    offset = 0
    for name, spectrum in spectra_data.items():
        colour, percent = colour_and_percent(name)
        ax2.plot(wavenumbers, normalize_spectrum(spectrum) + offset,
                 color=colour, linewidth=1)
        ax2.text(wavenumbers[-1] + 50, offset + 0.5, f"{percent}%", fontsize=8)
        offset += 1.2

    ax2.set_xlabel('Wavenumber (cm$^{-1}$)')
    ax2.set_ylabel('Intensity (offset)')
    ax2.set_title('Waterfall Plot: Concentration Series')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def visualize_augmentation_effects(original_spectrum, noise_data, wavenumbers, output_path):
    """Visualize noise types and baseline simulations used in augmentation.

    Saves a 3x2 figure; every panel shows the original spectrum together
    with one augmentation: Gaussian noise, dataset noise, power-law
    baseline, Gaussian baseline, shift and stretch.

    Args:
        original_spectrum: 1-D spectrum used as the reference trace.
        noise_data: Array indexed as ``[sample, 0, 0, :, 0]`` (pure) and
            ``[sample, 0, 1, :, 0]`` (noisy); sample 0 is used.
        wavenumbers: X-axis values for the spectrum.
        output_path: Destination image path.
    """
    fig, axes = plt.subplots(3, 2, figsize=(14, 12))
    spectrum_length = len(original_spectrum)
    x_range = np.linspace(0, spectrum_length, spectrum_length)

    def draw_panel(ax, title, overlays):
        # Original trace first, then each (curve, format, label) overlay.
        ax.plot(wavenumbers, original_spectrum, 'b-', label='Original', alpha=0.7)
        for curve, line_fmt, label in overlays:
            ax.plot(wavenumbers, curve, line_fmt, label=label, alpha=0.7)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)
        ax.set_xlabel('Wavenumber (cm$^{-1}$)')

    # (a) Original vs Gaussian noise
    # A private generator seeded with 42 draws the same numbers as
    # np.random.seed(42) would, without resetting the global random state
    # that the rest of the pipeline depends on.
    rng = np.random.RandomState(42)
    noisy = original_spectrum + rng.normal(0, 0.05, len(original_spectrum))
    draw_panel(axes[0, 0], '(a) Gaussian Noise Addition',
               [(noisy, 'r-', '+ Gaussian noise')])

    # (b) Original vs Dataset noise
    noise_idx = 0
    pure = noise_data[noise_idx, 0, 0, :, 0]
    noisy_sample = noise_data[noise_idx, 0, 1, :, 0]
    real_noise = noisy_sample - pure
    draw_panel(axes[0, 1], '(b) Real Instrument Noise Addition',
               [(original_spectrum + real_noise * 0.5, 'r-', '+ Real noise')])

    # (c) Polynomial baseline
    poly_bl = poly_baseline(x_range, p=2.0, intensity=0.8, b=0.0)
    draw_panel(axes[1, 0], '(c) Polynomial Baseline Simulation',
               [(poly_bl, 'g--', 'Polynomial baseline'),
                (original_spectrum + poly_bl, 'r-', 'With baseline')])

    # (d) Gaussian baseline (fluorescence)
    gauss_bl = gaussian_baseline(x_range, mean=440, sd=280, intensity=0.8, b=0.0)
    draw_panel(axes[1, 1], '(d) Gaussian Baseline (Fluorescence) Simulation',
               [(gauss_bl, 'g--', 'Gaussian baseline'),
                (original_spectrum + gauss_bl, 'r-', 'With baseline')])

    # (e) Shift augmentation
    shifted = shift_spectrum(original_spectrum, 10)
    draw_panel(axes[2, 0], '(e) Spectral Shift Augmentation',
               [(shifted, 'r-', 'Shifted (+10 pts)')])

    # (f) Stretch augmentation
    stretched = stretch_spectrum(original_spectrum, 1.5)
    draw_panel(axes[2, 1], '(f) Spectral Stretch Augmentation',
               [(stretched, 'r-', 'Stretched (a=1.5)')])

    plt.suptitle('Data Augmentation Techniques', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_preprocessing_comparison(raw_spectrum, hybrid_corrected, airpls_corrected,
                                   wavenumbers, output_path):
    """Save a three-row figure comparing raw, hybrid-corrected and airPLS-corrected spectra.

    Args:
        raw_spectrum: Uncorrected spectrum (top panel).
        hybrid_corrected: Spectrum after the hybrid preprocessing (middle panel).
        airpls_corrected: Spectrum after airPLS only (bottom panel).
        wavenumbers: Shared x-axis values.
        output_path: Destination image path.
    """
    # (trace, line format, extra fill_between kwargs, panel title)
    panels = (
        (raw_spectrum, 'b-', {},
         '(a) Raw Raman Spectrum with Fluorescence Baseline'),
        (hybrid_corrected, 'g-', {'color': 'green'},
         '(b) After Hybrid Deep Learning Preprocessing'),
        (airpls_corrected, 'r-', {'color': 'red'},
         '(c) After airPLS Baseline Correction'),
    )

    fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
    for ax, (trace, line_fmt, fill_kwargs, title) in zip(axes, panels):
        ax.plot(wavenumbers, trace, line_fmt, linewidth=1.5)
        ax.fill_between(wavenumbers, 0, trace, alpha=0.3, **fill_kwargs)
        ax.set_ylabel('Intensity')
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # Only the bottom panel carries the x-axis label (axes share x).
    axes[2].set_xlabel('Wavenumber (cm$^{-1}$)')

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def visualize_gadf_samples(X_2d, labels, output_path):
    """Save GADF images of the first sample of classes 0, 5 and 10.

    Args:
        X_2d: Image stack indexed as ``[sample, row, column, 0]``.
        labels: Class label of every sample in ``X_2d``.
        output_path: Destination image path.

    Raises:
        IndexError: If one of the three classes has no sample.
    """
    labels_arr = np.array(labels)

    def first_index_of(class_id):
        return np.where(labels_arr == class_id)[0][0]

    # (sample index, title, subtitle) for each of the three panels.
    samples = [
        (first_index_of(0), 'Pure Methanol (Class 0)', '100% Methanol'),
        (first_index_of(5), '50:50 Mixture (Class 5)', '50% Ethanol / 50% Methanol'),
        (first_index_of(10), 'Pure Ethanol (Class 10)', '100% Ethanol'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (sample_idx, title, subtitle) in zip(axes, samples):
        image = ax.imshow(X_2d[sample_idx, :, :, 0], cmap='viridis', aspect='auto',
                          extent=[0, 64, 64, 0])
        ax.set_title(f'{title}\n{subtitle}', fontsize=12)
        ax.set_xlabel('GADF Column')
        ax.set_ylabel('GADF Row')
        plt.colorbar(image, ax=ax, fraction=0.046)

    plt.suptitle('64x64 GADF Representations of Raman Spectra', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_peak_preservation_snr(raw_spectrum, processed_spectrum, wavenumbers,
                                metrics, output_path):
    """Save a before/after figure annotated with peaks, SNR and peak ratios.

    Args:
        raw_spectrum: Spectrum before preprocessing (top panel).
        processed_spectrum: Spectrum after preprocessing (bottom panel).
        wavenumbers: Shared x-axis values.
        metrics: Dict holding the SNR values and peak-ratio dictionaries;
            missing entries are shown as 0.
        output_path: Destination image path.
    """
    marked_bands = (('C-O (EtOH)', 1050), ('C-O (MeOH)', 1030), ('C-H', 2900))

    def draw_spectrum(ax, spectrum, line_fmt, legend_label):
        # Trace plus an arrow annotation at each reference wavenumber.
        ax.plot(wavenumbers, spectrum, line_fmt, linewidth=1.5, label=legend_label)
        for band_name, wn in marked_bands:
            idx = wavenumber_to_index(wn, wavenumbers)
            if idx >= len(spectrum):
                continue
            height = spectrum[idx]
            ax.annotate(band_name, xy=(wn, height),
                        xytext=(wn, height + 0.1),
                        arrowprops=dict(arrowstyle='->', color='red'),
                        fontsize=10, ha='center')
        ax.set_ylabel('Intensity')

    fig, (top, bottom) = plt.subplots(2, 1, figsize=(14, 10))

    # Top: Before preprocessing
    draw_spectrum(top, raw_spectrum, 'b-', 'Raw Spectrum')
    snr_raw = metrics.get('snr_raw', 0)
    top.set_title(f'Before Preprocessing (SNR = {snr_raw:.1f} dB)')
    top.legend()
    top.grid(True, alpha=0.3)

    # Bottom: After preprocessing
    draw_spectrum(bottom, processed_spectrum, 'g-', 'Processed Spectrum')
    bottom.set_xlabel('Wavenumber (cm$^{-1}$)')
    snr_processed = metrics.get('snr_processed', 0)
    snr_improvement = metrics.get('snr_improvement', 0)
    bottom.set_title(f'After Preprocessing (SNR = {snr_processed:.1f} dB, Improvement = {snr_improvement:.1f} dB)')
    bottom.legend()
    bottom.grid(True, alpha=0.3)

    # Text box comparing the peak ratios before and after.
    ratios_raw = metrics.get('peak_ratios_raw', {})
    ratios_proc = metrics.get('peak_ratios_processed', {})
    ratio_lines = [
        "Peak Ratio Preservation:",
        f"EtOH C-O/C-H: {ratios_raw.get('ethanol_co_to_ch', 0):.3f} -> "
        f"{ratios_proc.get('ethanol_co_to_ch', 0):.3f}",
        f"MeOH C-O/C-H: {ratios_raw.get('methanol_co_to_ch', 0):.3f} -> "
        f"{ratios_proc.get('methanol_co_to_ch', 0):.3f}",
    ]
    bottom.text(0.02, 0.98, "\n".join(ratio_lines), transform=bottom.transAxes, fontsize=10,
                verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

def plot_all_training_histories(histories, output_path):
    """Save a 2x2 figure of loss and accuracy curves, one panel per model.

    Args:
        histories: Mapping of model key to a history dict with ``'loss'``
            and ``'accuracy'`` and optionally ``'val_loss'`` and
            ``'val_accuracy'``. Models without an entry get a
            "(No data)" placeholder panel.
        output_path: Destination image path.
    """
    panel_titles = {
        'densenet_1d': 'DenseNet 1D',
        'densenet_2d': 'DenseNet 2D (GADF)',
        'resnet_1d': 'ResNet 1D',
        'resnet_2d': 'ResNet 2D (GADF)',
    }

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    for loss_ax, (name, title) in zip(axes.flat, panel_titles.items()):
        if name not in histories:
            loss_ax.text(0.5, 0.5, f'{title}\n(No data)', ha='center', va='center')
            loss_ax.set_title(title)
            continue
        h = histories[name]

        # Loss on the left axis, accuracy on a twin right axis.
        acc_ax = loss_ax.twinx()
        curves = [
            loss_ax.plot(h['loss'], 'b-', label='Train Loss')[0],
            loss_ax.plot(h.get('val_loss', []), 'b--', label='Val Loss')[0],
            acc_ax.plot(h['accuracy'], 'r-', label='Train Acc')[0],
            acc_ax.plot(h.get('val_accuracy', []), 'r--', label='Val Acc')[0],
        ]

        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss', color='blue')
        acc_ax.set_ylabel('Accuracy', color='red')
        loss_ax.set_title(title)

        # One combined legend for the curves of both axes.
        loss_ax.legend(curves, [curve.get_label() for curve in curves],
                       loc='center right', fontsize=8)
        loss_ax.grid(True, alpha=0.3)

    plt.suptitle('Model Training History', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close()

# ============================================================================
# DATA LOADING FUNCTIONS
# ============================================================================
def load_excel_data(target_length=880):
    """Load the ethanol/methanol spectra and their Raman-shift axis from Excel.

    Reads columns A:L of ``Ethanol_Methanol.xlsx`` in ``data_dir``. NaN
    entries are removed from every spectrum, which is then resampled to
    ``target_length`` points. The Raman-shift axis is resampled to the same
    length whenever its own length differs, so that it stays aligned with
    the spectra it is used to index and plot.

    Args:
        target_length: Number of points per returned spectrum and axis.

    Returns:
        ``(spectra_data, raman_shift)`` where ``spectra_data`` maps sample
        name to spectrum. On any loading error a message is printed and
        ``({}, np.array([]))`` is returned.
    """
    excel_path = os.path.join(data_dir, 'Ethanol_Methanol.xlsx')
    # Output key -> spreadsheet column, in the order the samples are stored.
    column_for = {'ethanol': 'Ethanol', 'methanol': 'Methanol'}
    column_for.update({f'EM{i}_a': f'EM{i}_a' for i in range(1, 10)})
    try:
        df = pd.read_excel(excel_path, usecols='A:L')

        raman_shift = df['Raman Shift (cm-1)'].values
        valid_axis = raman_shift[~pd.isna(raman_shift)]
        if len(valid_axis) != target_length:
            # Bring the axis onto the same grid as the resampled spectra.
            raman_shift = interpolate_spectrum(valid_axis.astype(float),
                                               len(valid_axis), target_length)
        else:
            raman_shift = valid_axis

        spectra_data = {}
        for key, column in column_for.items():
            spectrum = df[column].values
            spectrum = spectrum[~np.isnan(spectrum)]
            spectra_data[key] = interpolate_spectrum(spectrum, len(spectrum), target_length)
        return spectra_data, raman_shift
    except Exception as e:
        print(f"Error loading data from {excel_path}: {e}")
        return {}, np.array([])

def load_noise_data():
    """Load the noise array stored as ``dataset_noise_pure_182.npy`` in ``data_dir``.

    Returns:
        numpy.ndarray: The array as saved on disk, or an empty array when the
        file cannot be loaded (the error is printed rather than raised).
    """
    try:
        noise_path = os.path.join(data_dir, 'dataset_noise_pure_182.npy')
        loaded = np.load(noise_path)
    except Exception as e:
        print(f"Error loading noise data: {e}")
        return np.array([])
    return loaded

# ============================================================================
# MODEL DEFINITIONS
# ============================================================================
def build_1d_densenet(input_shape=(880, 1), num_classes=11, growth_rate=12):
    """Assemble a DenseNet-style 1D convolutional classifier.

    Layout: a 48-filter, width-7 convolution stem followed by batch norm and
    2x max pooling; three stages, each a 4-layer dense block followed by a
    transition that halves the channel count with a 1x1 convolution and pools
    by 2; then BN -> ReLU -> global average pooling -> Dense(128) ->
    Dropout(0.4) -> softmax. Every convolution uses 'same' padding and an L2
    kernel penalty of 0.0005.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        num_classes: Width of the softmax output layer.
        growth_rate: Number of feature maps appended by each dense-block layer.

    Returns:
        The uncompiled ``models.Model``.
    """
    def conv(filters, kernel_size, **extra):
        # All convolutions share padding and weight decay; only size/activation vary.
        return layers.Conv1D(filters, kernel_size, padding='same',
                             kernel_regularizer=regularizers.l2(0.0005), **extra)

    def bn_relu(tensor):
        normalized = layers.BatchNormalization()(tensor)
        return layers.Activation('relu')(normalized)

    def dense_block(features, num_layers, filters):
        # Each layer sees the concatenation of everything produced before it.
        for _ in range(num_layers):
            activated = bn_relu(features)
            fresh_maps = conv(filters, 3)(activated)
            features = layers.Concatenate()([features, fresh_maps])
        return features

    def transition_layer(features):
        channels = features.shape[-1]
        activated = bn_relu(features)
        squeezed = conv(channels // 2, 1)(activated)
        return layers.MaxPooling1D(pool_size=2)(squeezed)

    inputs = layers.Input(shape=input_shape)
    net = conv(48, 7, activation='relu')(inputs)
    net = layers.BatchNormalization()(net)
    net = layers.MaxPooling1D(pool_size=2)(net)

    for _ in range(3):
        net = dense_block(net, num_layers=4, filters=growth_rate)
        net = transition_layer(net)

    net = bn_relu(net)
    net = layers.GlobalAveragePooling1D()(net)
    net = layers.Dense(128, activation='relu')(net)
    net = layers.Dropout(0.4)(net)
    outputs = layers.Dense(num_classes, activation='softmax')(net)
    return models.Model(inputs, outputs)

def build_2d_densenet(input_shape=(64, 64, 1), num_classes=11, growth_rate=12):
    """Assemble a DenseNet-style 2D convolutional classifier.

    Layout: a 48-filter 3x3 convolution stem followed by batch norm and 2x2
    max pooling; three stages, each a 4-layer dense block followed by a
    transition that halves the channel count with a 1x1 convolution and pools
    2x2; then BN -> ReLU -> global average pooling -> Dense(128) ->
    Dropout(0.4) -> softmax. Every convolution uses 'same' padding and an L2
    kernel penalty of 0.0005.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        num_classes: Width of the softmax output layer.
        growth_rate: Number of feature maps appended by each dense-block layer.

    Returns:
        The uncompiled ``models.Model``.
    """
    def conv(filters, kernel_size, **extra):
        # All convolutions share padding and weight decay; only size/activation vary.
        return layers.Conv2D(filters, kernel_size, padding='same',
                             kernel_regularizer=regularizers.l2(0.0005), **extra)

    def bn_relu(tensor):
        normalized = layers.BatchNormalization()(tensor)
        return layers.Activation('relu')(normalized)

    def dense_block(features, num_layers, filters):
        # Each layer sees the concatenation of everything produced before it.
        for _ in range(num_layers):
            activated = bn_relu(features)
            fresh_maps = conv(filters, 3)(activated)
            features = layers.Concatenate()([features, fresh_maps])
        return features

    def transition_layer(features):
        channels = features.shape[-1]
        activated = bn_relu(features)
        squeezed = conv(channels // 2, 1)(activated)
        return layers.MaxPooling2D(pool_size=(2, 2))(squeezed)

    inputs = layers.Input(shape=input_shape)
    net = conv(48, 3, activation='relu')(inputs)
    net = layers.BatchNormalization()(net)
    net = layers.MaxPooling2D(pool_size=(2, 2))(net)

    for _ in range(3):
        net = dense_block(net, num_layers=4, filters=growth_rate)
        net = transition_layer(net)

    net = bn_relu(net)
    net = layers.GlobalAveragePooling2D()(net)
    net = layers.Dense(128, activation='relu')(net)
    net = layers.Dropout(0.4)(net)
    outputs = layers.Dense(num_classes, activation='softmax')(net)
    return models.Model(inputs, outputs)

def build_1d_resnet(input_shape=(880, 1), num_classes=11):
    """Assemble a ResNet-style 1D convolutional classifier.

    Layout: a 64-filter, width-5 convolution stem with batch norm and 2x max
    pooling; two 64-filter residual blocks; 2x max pooling; three 128-filter
    residual blocks; then global average pooling -> Dense(128) ->
    Dropout(0.3) -> softmax. The regularized convolutions use an L2 kernel
    penalty of 0.0001; the 1x1 shortcut projection is not regularized.

    Args:
        input_shape: Shape of one input sample, without the batch axis.
        num_classes: Width of the softmax output layer.

    Returns:
        The uncompiled ``models.Model``.
    """
    def conv_relu(filters, kernel_size):
        return layers.Conv1D(filters, kernel_size, padding='same', activation='relu',
                             kernel_regularizer=regularizers.l2(0.0001))

    def residual_block(tensor, filters, kernel_size=3):
        skip = tensor
        # Two identical conv(ReLU) -> BN stages form the residual branch.
        for _ in range(2):
            tensor = conv_relu(filters, kernel_size)(tensor)
            tensor = layers.BatchNormalization()(tensor)
        # Project the shortcut with a 1x1 convolution when channel counts differ.
        if skip.shape[-1] != filters:
            skip = layers.Conv1D(filters, 1, padding='same')(skip)
        merged = layers.Add()([skip, tensor])
        return layers.Activation('relu')(merged)

    inputs = layers.Input(shape=input_shape)
    net = conv_relu(64, 5)(inputs)
    net = layers.BatchNormalization()(net)
    net = layers.MaxPooling1D(pool_size=2)(net)

    for width in (64, 64):
        net = residual_block(net, width)
    net = layers.MaxPooling1D(pool_size=2)(net)
    for width in (128, 128, 128):
        net = residual_block(net, width)

    net = layers.GlobalAveragePooling1D()(net)
    net = layers.Dense(128, activation='relu')(net)
    net = layers.Dropout(0.3)(net)
    outputs = layers.Dense(num_classes, activation='softmax')(net)
    return models.Model(inputs, outputs)

def build_2d_resnet(input_shape=(64, 64, 1), num_classes=11):
    """Assemble a small ResNet-style 2D convolutional classifier.

    Layout: a 32-filter 3x3 convolution stem with batch norm and 2x2 max
    pooling; two 32-filter residual blocks; 2x2 max pooling; then global
    average pooling -> Dense(128) -> Dropout(0.5) -> softmax. The stem and the
    first convolution of each residual block carry an L2 kernel penalty of
    0.0005; the second convolution and the shortcut projection do not.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        num_classes: Width of the softmax output layer.

    Returns:
        The uncompiled ``models.Model``.
    """
    def conv_relu(filters, kernel_size):
        return layers.Conv2D(filters, kernel_size, padding='same', activation='relu',
                             kernel_regularizer=regularizers.l2(0.0005))

    def residual_block(tensor, filters, kernel_size=3):
        skip = tensor
        branch = conv_relu(filters, kernel_size)(tensor)
        branch = layers.BatchNormalization()(branch)
        # Second convolution is linear and unregularized; ReLU follows the add.
        branch = layers.Conv2D(filters, kernel_size, padding='same')(branch)
        branch = layers.BatchNormalization()(branch)
        # Project the shortcut with a 1x1 convolution when channel counts differ.
        if skip.shape[-1] != filters:
            skip = layers.Conv2D(filters, 1, padding='same')(skip)
        merged = layers.Add()([skip, branch])
        return layers.Activation('relu')(merged)

    inputs = layers.Input(shape=input_shape)
    net = conv_relu(32, 3)(inputs)
    net = layers.BatchNormalization()(net)
    net = layers.MaxPooling2D(pool_size=(2, 2))(net)

    for width in (32, 32):
        net = residual_block(net, width)
    net = layers.MaxPooling2D(pool_size=(2, 2))(net)

    net = layers.GlobalAveragePooling2D()(net)
    net = layers.Dense(128, activation='relu')(net)
    net = layers.Dropout(0.5)(net)
    outputs = layers.Dense(num_classes, activation='softmax')(net)
    return models.Model(inputs, outputs)

# ============================================================================
# CONFUSION MATRIX
# ============================================================================
def plot_confusion_matrix(y_true, y_pred, title, filename):
    """Save an annotated confusion-matrix heatmap for the 11 ethanol classes.

    Args:
        y_true: True integer class labels (0-10).
        y_pred: Predicted integer class labels (0-10).
        title: Figure title.
        filename: Output file name; the image is written inside
            ``visualizations_dir``.
    """
    tick_labels = [f"{i*10}% Ethanol" for i in range(11)]
    # Pin the label set so the matrix is always 11x11. Without it, a class that
    # is absent from both y_true and y_pred shrinks the matrix and the tick
    # labels no longer line up with the rows/columns.
    cm = confusion_matrix(y_true, y_pred, labels=np.arange(len(tick_labels)))
    fig = plt.figure(figsize=(10, 8))
    try:
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=tick_labels,
                    yticklabels=tick_labels)
        plt.title(title)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.savefig(os.path.join(visualizations_dir, filename), dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

# ============================================================================
# DATA AUGMENTATION LAYERS
# ============================================================================
# The augmentation functions declare a `training` parameter so that Keras'
# Lambda layer forwards the flag to them. They are active only while training;
# during validation and predict() the input passes through unchanged, so
# evaluation is not run on randomly perturbed data.
def _add_gaussian_noise(x, training=None):
    """Add zero-mean Gaussian noise (stddev 0.05) to every element when training."""
    if not training:
        return x
    return x + tf.random.normal(tf.shape(x), mean=0.0, stddev=0.05)


def _random_intensity_scale(x, training=None):
    """Multiply the whole batch by one random factor in [0.8, 1.2) when training."""
    if not training:
        return x
    return x * tf.random.uniform((), 0.8, 1.2)


def _random_axis_shift(x, training=None):
    """Circularly shift the batch along axis 1 by a random offset when training."""
    if not training:
        return x
    # maxval is exclusive for integer dtypes, so 6 gives a symmetric range of
    # -5..5 positions (the previous bound of 5 could never shift by +5).
    shift = tf.random.uniform((), -5, 6, dtype=tf.int32)
    return tf.roll(x, shift=shift, axis=1)


def _build_augmentation():
    """Return a fresh noise -> scale -> shift augmentation pipeline."""
    return models.Sequential([
        layers.Lambda(_add_gaussian_noise),
        layers.Lambda(_random_intensity_scale),
        layers.Lambda(_random_axis_shift),
    ])


# The 1D and 2D pipelines apply the same three steps; separate instances are
# kept so each family of models owns its own augmentation stack.
data_augmentation_1d = _build_augmentation()
data_augmentation_2d = _build_augmentation()

# ============================================================================
# LOAD AND PROCESS DATA
# ============================================================================
# Both loaders report failure by returning an empty container, so check for
# emptiness and stop here if either input is missing.
spectra_data, raman_shift = load_excel_data()
noise_data = load_noise_data()
inputs_missing = noise_data.size == 0 or not spectra_data
if inputs_missing:
    raise FileNotFoundError("Cannot load data from Excel file or noise file.")

# Baseline-estimation model: prefer the stored weights; if they cannot be
# loaded for any reason, fit the model on the noise data instead.
baseline_model = create_baseline_model(input_shape=880)
try:
    baseline_model.load_weights(os.path.join(data_dir, 'model.weights.h5'))
except Exception as e:
    print(f"Error loading weights: {e}. Training baseline model.")
    baseline_model = train_baseline_model(baseline_model, noise_data)
else:
    print("Baseline model weights loaded successfully!")

# ============================================================================
# GENERATE SYNTHETIC DATA
# ============================================================================
# Accumulators for the dataset. They are converted to arrays once generation
# is finished (see "Convert to numpy arrays" below).
X_1d = []
X_2d = []
labels = []
sample_ids = []
example_spectra = {}
total_spectra = 0
spectrum_length = 880
# Number of synthetic variants generated from each measured spectrum.
synthetic_per_class = 999

# Ethanol fraction -> integer class label (one class per 10% step).
ratio_to_label = {0.0: 0, 0.1: 1, 0.2: 2, 0.3: 3, 0.4: 4, 0.5: 5, 0.6: 6, 0.7: 7, 0.8: 8, 0.9: 9, 1.0: 10}
# Spectrum key in `spectra_data` -> ethanol fraction of that sample.
ratios = {
    'ethanol': 1.0,
    'methanol': 0.0,
    'EM1_a': 0.9,
    'EM2_a': 0.8,
    'EM3_a': 0.7,
    'EM4_a': 0.6,
    'EM5_a': 0.5,
    'EM6_a': 0.4,
    'EM7_a': 0.3,
    'EM8_a': 0.2,
    'EM9_a': 0.1
}

# Store raw and corrected spectra for visualization
raw_spectra_for_viz = {}
corrected_spectra_for_viz = {}


def _correct_and_store(spectrum, label_idx, sample_id):
    """Baseline-correct and re-normalize *spectrum*, then append it to the dataset.

    The corrected spectrum goes to ``X_1d``, its GADF map to ``X_2d``, and the
    label / id to ``labels`` / ``sample_ids``. Returns the corrected spectrum.
    """
    corrected = enhanced_baseline_correction(spectrum, baseline_model)
    corrected = normalize_spectrum(corrected)
    X_1d.append(corrected)
    X_2d.append(create_gadf_map(corrected, image_size=64))
    labels.append(label_idx)
    sample_ids.append(sample_id)
    return corrected


print("Starting synthetic data generation...")
for spectrum_type, ethanol_ratio in ratios.items():
    print(f"Processing {spectrum_type} with ethanol ratio {ethanol_ratio}...")
    input_spectrum = spectra_data[spectrum_type]
    normalized_raw = normalize_spectrum(input_spectrum)
    label_idx = ratio_to_label[ethanol_ratio]

    # Store for visualization
    raw_spectra_for_viz[spectrum_type] = normalized_raw.copy()

    # The measured spectrum itself is the first sample of its class.
    corrected_spectrum = _correct_and_store(normalized_raw, label_idx, f"{spectrum_type}_original")
    corrected_spectra_for_viz[spectrum_type] = corrected_spectrum.copy()
    total_spectra += 1

    # The example kept for the legacy comparison plot is the first synthetic
    # spectrum; the measured one is only a fallback when none are generated.
    example_raw, example_corrected = normalized_raw, corrected_spectrum

    for i in range(synthetic_per_class):
        synthetic_spectrum = generate_synthetic_spectrum(normalized_raw, noise_data, spectrum_length)
        corrected_spectrum = _correct_and_store(synthetic_spectrum, label_idx, f"{spectrum_type}_synthetic_{i}")
        total_spectra += 1
        if i == 0:
            example_raw, example_corrected = synthetic_spectrum, corrected_spectrum
        if total_spectra % 1000 == 0:
            print(f"Processed {total_spectra} spectra.")

    # Pure-component references are attached only when that component is present.
    example_spectra[ethanol_ratio] = {
        'raw': example_raw,
        'corrected': example_corrected,
        'ethanol': spectra_data['ethanol'] if ethanol_ratio > 0 else None,
        'methanol': spectra_data['methanol'] if ethanol_ratio < 1 else None
    }

print(f"Total spectra: {total_spectra}")

# Convert to numpy arrays
X_1d = np.array(X_1d)[:, :, np.newaxis]  # trailing channel axis for the 1D models
X_2d = np.array(X_2d)
labels_df = pd.DataFrame({'label': labels, 'sample_id': sample_ids})

print("X_1d shape:", X_1d.shape)
print("X_2d shape:", X_2d.shape)
print("labels_df shape:", labels_df.shape)
print("Label distribution:\n", labels_df["label"].value_counts())

# Save data
labels_df.to_csv(os.path.join(labels_dir, "labels.csv"), index=False)
np.save(os.path.join(synthetic_dir, "synthetic_1d.npy"), X_1d)
np.save(os.path.join(maps_dir, "spectral_maps_gadf.npy"), X_2d)

# ============================================================================
# GENERATE VISUALIZATIONS (Tasks 4-8)
# ============================================================================
# Wavenumber axis shared by every plot below and by the later occlusion
# analysis: 880 evenly spaced points from 500 to 3500.
# NOTE(review): `raman_shift` returned by load_excel_data() is not used here;
# confirm that this synthetic axis matches the measured Raman-shift range.
wavenumbers = np.linspace(500, 3500, 880)

# Task 4: Representative spectra visualization
print("Generating representative spectra visualization...")
plot_representative_spectra_by_concentration(
    spectra_data, wavenumbers, ratios,
    os.path.join(visualizations_dir, 'representative_spectra.png')
)

# Task 5: Augmentation effects visualization
# The 'EM5_a' spectrum is used as the demonstration sample.
print("Generating augmentation effects visualization...")
sample_spectrum = normalize_spectrum(spectra_data['EM5_a'])
visualize_augmentation_effects(
    sample_spectrum, noise_data, wavenumbers,
    os.path.join(visualizations_dir, 'augmentation_effects.png')
)

# Task 6: Preprocessing comparison
# A Gaussian baseline is added to the normalized sample, then removed two ways:
# with enhanced_baseline_correction (using baseline_model) and with airPLS.
print("Generating preprocessing comparison visualization...")
sample_raw = normalize_spectrum(spectra_data['EM5_a'])
sample_synthetic = sample_raw + gaussian_baseline(
    np.linspace(0, 880, 880), mean=440, sd=280, intensity=0.8, b=0.0
)
sample_hybrid = enhanced_baseline_correction(sample_synthetic, baseline_model)
# Subtract the airPLS baseline estimate and clamp negative intensities to zero.
sample_airpls = np.clip(sample_synthetic - airPLS(sample_synthetic), 0, None)
plot_preprocessing_comparison(
    sample_synthetic, normalize_spectrum(sample_hybrid), normalize_spectrum(sample_airpls),
    wavenumbers, os.path.join(visualizations_dir, 'preprocessing_comparison.png')
)

# Task 7: GADF visualization
print("Generating GADF visualization...")
visualize_gadf_samples(X_2d, labels, os.path.join(visualizations_dir, 'gadf_samples.png'))

# Task 8: Peak preservation and SNR improvement
# The corrected spectrum stored during data generation is compared with the
# normalized raw one; if 'EM5_a' was not stored, the raw sample is used for both.
# `metrics` is kept at module level because the results export reads it later.
print("Generating peak preservation/SNR visualization...")
sample_raw = normalize_spectrum(spectra_data['EM5_a'])
sample_corrected = normalize_spectrum(corrected_spectra_for_viz.get('EM5_a', sample_raw))
metrics = evaluate_preprocessing_quality(sample_raw, sample_corrected, wavenumbers)
plot_peak_preservation_snr(
    sample_raw, sample_corrected, wavenumbers, metrics,
    os.path.join(visualizations_dir, 'peak_preservation_snr.png')
)

# Legacy visualization
def plot_spectra_comparison(wavenumbers, raw_spectrum, corrected_spectrum, ethanol_spectrum, methanol_spectrum, title, filename):
    """Save a two-panel figure comparing a raw spectrum with its corrected form.

    The top panel shows the raw spectrum. The bottom panel shows the corrected
    spectrum, overlaid with the normalized pure-ethanol and pure-methanol
    references when they are given. Both panels shade the 1000-1020 band
    (labelled as the methanol region) and the 870-890 band (labelled as the
    ethanol region).

    Args:
        wavenumbers: X-axis values shared by all curves.
        raw_spectrum: Spectrum drawn in the top panel.
        corrected_spectrum: Spectrum drawn in the bottom panel.
        ethanol_spectrum: Optional reference; skipped when ``None``.
        methanol_spectrum: Optional reference; skipped when ``None``.
        title: Text inserted into both panel titles.
        filename: Output file name inside ``visualizations_dir``.
    """
    def finish_panel(ax, heading):
        # Axis decoration common to both panels.
        ax.set_xlabel("Wavenumber (cm$^{-1}$)")
        ax.set_ylabel("Normalized Intensity")
        ax.set_title(heading)
        ax.grid(True)
        ax.legend()

    fig = plt.figure(figsize=(15, 10))

    # Top panel: raw spectrum with labelled marker regions.
    raw_ax = fig.add_subplot(2, 1, 1)
    raw_ax.plot(wavenumbers, raw_spectrum, label="Raw Spectrum", color='blue')
    raw_ax.axvspan(1000, 1020, color='red', alpha=0.2, label="Methanol Region")
    raw_ax.axvspan(870, 890, color='green', alpha=0.2, label="Ethanol Region")
    finish_panel(raw_ax, f"Raw Spectrum ({title})")

    # Bottom panel: corrected spectrum plus whichever references are available.
    corrected_ax = fig.add_subplot(2, 1, 2)
    corrected_ax.plot(wavenumbers, corrected_spectrum, label="Corrected Spectrum", color='orange')
    if ethanol_spectrum is not None:
        corrected_ax.plot(wavenumbers, normalize_spectrum(ethanol_spectrum),
                          label="Ethanol Component", color='green', linestyle='--')
    if methanol_spectrum is not None:
        corrected_ax.plot(wavenumbers, normalize_spectrum(methanol_spectrum),
                          label="Methanol Component", color='red', linestyle='--')
    corrected_ax.axvspan(1000, 1020, color='red', alpha=0.2)
    corrected_ax.axvspan(870, 890, color='green', alpha=0.2)
    finish_panel(corrected_ax, f"Baseline Corrected Spectrum ({title})")

    fig.tight_layout()
    fig.savefig(os.path.join(visualizations_dir, filename), dpi=300, bbox_inches='tight')
    plt.close(fig)

def _comparison_title(ratio):
    """Return the figure title for an ethanol fraction *ratio* between 0.0 and 1.0."""
    # The pure components are checked first. Testing `ratio < 1.0` first would
    # also catch 0.0 and make the "Pure Methanol" title unreachable.
    if ratio == 1.0:
        return "Pure Ethanol"
    if ratio == 0.0:
        return "Pure Methanol"
    return f"Ratio Ethanol/Methanol = {ratio:.2f}/{1-ratio:.2f}"


# One legacy comparison figure per ethanol ratio collected during generation.
for ratio, example in example_spectra.items():
    title = _comparison_title(ratio)
    plot_spectra_comparison(
        wavenumbers,
        example['raw'],
        example['corrected'],
        example['ethanol'],
        example['methanol'],
        title,
        f"spectra_comparison_ratio_{ratio:.2f}.png"
    )

# ============================================================================
# SPLIT DATA
# ============================================================================
y = labels_df["label"].values

# A single call splits the 1D spectra, the 2D maps and the labels with the same
# shuffled indices (80% train / 20% test, fixed seed), so the 1D and 2D models
# are trained and evaluated on the same samples.
(X_1d_train, X_1d_test,
 X_2d_train, X_2d_test,
 y_train, y_test) = train_test_split(X_1d, X_2d, y, test_size=0.2, random_state=42)
# Separate label arrays are kept for the 2D pipeline; they hold the same values.
y_train_2d, y_test_2d = y_train.copy(), y_test.copy()

# Compute class weights ('balanced' scheme over all 11 classes, from the full label set)
class_weights = compute_class_weight('balanced', classes=np.arange(11), y=y)
class_weight = dict(enumerate(class_weights))

In [None]:
# ============================================================================
# TRAIN MODELS WITH HISTORY CAPTURE (Task 9)
# ============================================================================
# One EarlyStopping callback instance is shared by all four fit() calls below.
# NOTE(review): patience (10) equals the number of epochs used in every fit
# (10), so training presumably never stops early; confirm the intended values.
early_stopping = keras.callbacks.EarlyStopping(patience=10, restore_best_weights=True)

# Dictionary to store training histories
histories = {}

# Each model below follows the same recipe: clear the Keras session, build a
# cosine-decay Adam optimizer, wrap augmentation + backbone in a Sequential
# model, and fit for 10 epochs with 10% of the training data held out for
# validation. The best checkpoint is written to `model_dir`.
# NOTE(review): decay_steps is computed from the full training-array length;
# the samples held out by validation_split are not subtracted.

# DenseNet 1D
print("Training DenseNet 1D...")
tf.keras.backend.clear_session()
lr_schedule_1d_densenet = keras.optimizers.schedules.CosineDecay(initial_learning_rate=0.001, decay_steps=10*len(X_1d_train)//64)
optimizer_1d_densenet = keras.optimizers.Adam(learning_rate=lr_schedule_1d_densenet)
densenet_1d = models.Sequential([data_augmentation_1d, build_1d_densenet()])
densenet_1d.compile(optimizer=optimizer_1d_densenet, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history_densenet_1d = densenet_1d.fit(X_1d_train, y_train, validation_split=0.1, epochs=10, batch_size=64,
                callbacks=[keras.callbacks.ModelCheckpoint(os.path.join(model_dir, "best_densenet_1d.keras"), save_best_only=True), early_stopping],
                class_weight=class_weight)
histories['densenet_1d'] = history_densenet_1d.history

# DenseNet 2D (trained on the 2D maps, batch size 32)
print("Training DenseNet 2D...")
tf.keras.backend.clear_session()
lr_schedule_2d_densenet = keras.optimizers.schedules.CosineDecay(initial_learning_rate=0.001, decay_steps=10*len(X_2d_train)//32)
optimizer_2d_densenet = keras.optimizers.Adam(learning_rate=lr_schedule_2d_densenet)
densenet_2d = models.Sequential([data_augmentation_2d, build_2d_densenet()])
densenet_2d.compile(optimizer=optimizer_2d_densenet, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history_densenet_2d = densenet_2d.fit(X_2d_train, y_train_2d, validation_split=0.1, epochs=10, batch_size=32,
                callbacks=[keras.callbacks.ModelCheckpoint(os.path.join(model_dir, "best_densenet_2d.keras"), save_best_only=True), early_stopping],
                class_weight=class_weight)
histories['densenet_2d'] = history_densenet_2d.history

# ResNet 1D
print("Training ResNet 1D...")
tf.keras.backend.clear_session()
lr_schedule_1d_resnet = keras.optimizers.schedules.CosineDecay(initial_learning_rate=0.001, decay_steps=10*len(X_1d_train)//64)
optimizer_1d_resnet = keras.optimizers.Adam(learning_rate=lr_schedule_1d_resnet)
resnet_1d = models.Sequential([data_augmentation_1d, build_1d_resnet()])
resnet_1d.compile(optimizer=optimizer_1d_resnet, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history_resnet_1d = resnet_1d.fit(X_1d_train, y_train, validation_split=0.1, epochs=10, batch_size=64,
              callbacks=[keras.callbacks.ModelCheckpoint(os.path.join(model_dir, "best_resnet_1d.keras"), save_best_only=True), early_stopping],
              class_weight=class_weight)
histories['resnet_1d'] = history_resnet_1d.history

# ResNet 2D (trained on the 2D maps, batch size 32)
print("Training ResNet 2D...")
tf.keras.backend.clear_session()
lr_schedule_2d_resnet = keras.optimizers.schedules.CosineDecay(initial_learning_rate=0.001, decay_steps=10*len(X_2d_train)//32)
optimizer_2d_resnet = keras.optimizers.Adam(learning_rate=lr_schedule_2d_resnet)
resnet_2d = models.Sequential([data_augmentation_2d, build_2d_resnet()])
resnet_2d.compile(optimizer=optimizer_2d_resnet, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
history_resnet_2d = resnet_2d.fit(X_2d_train, y_train_2d, validation_split=0.1, epochs=10, batch_size=32,
              callbacks=[keras.callbacks.ModelCheckpoint(os.path.join(model_dir, "best_resnet_2d.keras"), save_best_only=True), early_stopping],
              class_weight=class_weight)
histories['resnet_2d'] = history_resnet_2d.history

# Plot training histories
print("Generating training curves visualization...")
plot_all_training_histories(histories, os.path.join(visualizations_dir, 'training_curves.png'))

# Save histories for later analysis
# (json is already imported at the top of the file; this repeat import is harmless.)
import json
# Convert numpy arrays to lists for JSON serialization
# Every per-epoch value is cast to a plain float so json.dump can write it.
histories_serializable = {}
for k, v in histories.items():
    histories_serializable[k] = {key: [float(x) for x in val] for key, val in v.items()}
with open(os.path.join(results_dir, 'training_histories.json'), 'w') as f:
    json.dump(histories_serializable, f, indent=2)
print("Training histories saved.")

In [None]:
# ============================================================================
# EVALUATION & OCCLUSION ANALYSIS (Task 2)
# ============================================================================

# Model predictions
def _macro_scores(y_true, y_hat):
    """Return (accuracy, macro precision, macro recall, macro F1) for hard predictions."""
    return (
        np.mean(y_hat == y_true),
        precision_score(y_true, y_hat, average='macro'),
        recall_score(y_true, y_hat, average='macro'),
        f1_score(y_true, y_hat, average='macro'),
    )


# Class-probability predictions on the held-out test split, one call per model.
y_pred_1d_densenet = densenet_1d.predict(X_1d_test)
y_pred_2d_densenet = densenet_2d.predict(X_2d_test)
y_pred_1d_resnet = resnet_1d.predict(X_1d_test)
y_pred_2d_resnet = resnet_2d.predict(X_2d_test)

# Hard labels: index of the most probable class.
y_pred_1d_densenet_labels = y_pred_1d_densenet.argmax(axis=1)
y_pred_2d_densenet_labels = y_pred_2d_densenet.argmax(axis=1)
y_pred_1d_resnet_labels = y_pred_1d_resnet.argmax(axis=1)
y_pred_2d_resnet_labels = y_pred_2d_resnet.argmax(axis=1)

# Accuracy and macro-averaged precision / recall / F1 for every model.
(densenet_1d_acc, densenet_1d_precision,
 densenet_1d_recall, densenet_1d_f1) = _macro_scores(y_test, y_pred_1d_densenet_labels)
(densenet_2d_acc, densenet_2d_precision,
 densenet_2d_recall, densenet_2d_f1) = _macro_scores(y_test_2d, y_pred_2d_densenet_labels)
(resnet_1d_acc, resnet_1d_precision,
 resnet_1d_recall, resnet_1d_f1) = _macro_scores(y_test, y_pred_1d_resnet_labels)
(resnet_2d_acc, resnet_2d_precision,
 resnet_2d_recall, resnet_2d_f1) = _macro_scores(y_test_2d, y_pred_2d_resnet_labels)

print("\n" + "="*70)
print("EVALUATION RESULTS")
print("="*70)
print(f"DenseNet 1D - Accuracy: {densenet_1d_acc:.4f}, Precision: {densenet_1d_precision:.4f}, Recall: {densenet_1d_recall:.4f}, F1: {densenet_1d_f1:.4f}")
print(f"DenseNet 2D (GADF) - Accuracy: {densenet_2d_acc:.4f}, Precision: {densenet_2d_precision:.4f}, Recall: {densenet_2d_recall:.4f}, F1: {densenet_2d_f1:.4f}")
print(f"ResNet 1D - Accuracy: {resnet_1d_acc:.4f}, Precision: {resnet_1d_precision:.4f}, Recall: {resnet_1d_recall:.4f}, F1: {resnet_1d_f1:.4f}")
print(f"ResNet 2D (GADF) - Accuracy: {resnet_2d_acc:.4f}, Precision: {resnet_2d_precision:.4f}, Recall: {resnet_2d_recall:.4f}, F1: {resnet_2d_f1:.4f}")

# One confusion-matrix image per model, written in the same order as above.
print("\nGenerating confusion matrices...")
confusion_jobs = (
    (y_test, y_pred_1d_densenet_labels, "DenseNet 1D Confusion Matrix", "confusion_densenet_1d.png"),
    (y_test_2d, y_pred_2d_densenet_labels, "DenseNet 2D (GADF) Confusion Matrix", "confusion_densenet_2d.png"),
    (y_test, y_pred_1d_resnet_labels, "ResNet 1D Confusion Matrix", "confusion_resnet_1d.png"),
    (y_test_2d, y_pred_2d_resnet_labels, "ResNet 2D (GADF) Confusion Matrix", "confusion_resnet_2d.png"),
)
for cm_true, cm_pred, cm_title, cm_file in confusion_jobs:
    plot_confusion_matrix(cm_true, cm_pred, cm_title, cm_file)

# ============================================================================
# CHEMICALLY MEANINGFUL OCCLUSION ANALYSIS (Task 2)
# ============================================================================
section_rule = "="*70
print("\n" + section_rule)
print("CHEMICALLY MEANINGFUL OCCLUSION ANALYSIS")
print(section_rule)
print("Performing occlusion analysis on best 1D model (DenseNet 1D)...")

# Occlusion study on the DenseNet 1D model over the bands listed in RAMAN_BANDS.
# The model is fixed here; it is not selected from the metrics computed above.
occlusion_results = chemically_meaningful_occlusion(
    densenet_1d, X_1d_test, y_test, wavenumbers, RAMAN_BANDS
)

print(f"\nBaseline Accuracy: {occlusion_results['baseline_accuracy']:.2f}%")
print("\nAccuracy Drop by Raman Band:")
print("-" * 50)
# One report line per band: name, accuracy drop and wavenumber range.
for band_data in occlusion_results['bands'].values():
    band_line = (f"  {band_data['name']:25s}: {band_data['accuracy_drop']:+6.2f}% "
                 f"({band_data['wavenumber_range'][0]}-{band_data['wavenumber_range'][1]} cm^-1)")
    print(band_line)

# Save the occlusion figure next to the other visualizations.
occlusion_plot_path = os.path.join(visualizations_dir, 'occlusion_analysis.png')
plot_occlusion_analysis(occlusion_results, occlusion_plot_path)

# ============================================================================
# SAVE COMPREHENSIVE RESULTS
# ============================================================================
def _json_number(value):
    """Return *value* as a float, or as 'inf' / '-inf' / 'nan' when it is not finite.

    Non-finite floats have no standard JSON representation, so they are stored
    as strings; the sign of an infinity is preserved.
    """
    number = float(value)
    if np.isnan(number):
        return 'nan'
    if np.isinf(number):
        return 'inf' if number > 0 else '-inf'
    return number


def _to_builtin(value):
    """Convert NumPy scalars/arrays (also inside lists/tuples) to plain Python objects.

    json.dump cannot serialize NumPy scalar types such as float32 or int64;
    anything that is not a NumPy object is returned unchanged.
    """
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_builtin(item) for item in value]
    return value


# (accuracy, precision, recall, f1) per model, as computed in the evaluation step.
model_scores = {
    'densenet_1d': (densenet_1d_acc, densenet_1d_precision, densenet_1d_recall, densenet_1d_f1),
    'densenet_2d': (densenet_2d_acc, densenet_2d_precision, densenet_2d_recall, densenet_2d_f1),
    'resnet_1d': (resnet_1d_acc, resnet_1d_precision, resnet_1d_recall, resnet_1d_f1),
    'resnet_2d': (resnet_2d_acc, resnet_2d_precision, resnet_2d_recall, resnet_2d_f1),
}

results = {
    'models': {
        model_name: {
            'accuracy': float(acc),
            'precision': float(prec),
            'recall': float(rec),
            'f1': float(f1)
        }
        for model_name, (acc, prec, rec, f1) in model_scores.items()
    },
    'occlusion_analysis': {
        'baseline_accuracy': _to_builtin(occlusion_results['baseline_accuracy']),
        'bands': {k: {
            'name': v['name'],
            'wavenumber_range': _to_builtin(v['wavenumber_range']),
            'accuracy_drop': _to_builtin(v['accuracy_drop']),
            'occluded_accuracy': _to_builtin(v['occluded_accuracy'])
        } for k, v in occlusion_results['bands'].items()}
    },
    'preprocessing_quality': {
        'snr_raw': _json_number(metrics['snr_raw']),
        'snr_processed': _json_number(metrics['snr_processed']),
        'snr_improvement': _json_number(metrics['snr_improvement']),
        'peak_ratios_raw': {k: _json_number(v) for k, v in metrics['peak_ratios_raw'].items()},
        'peak_ratios_processed': {k: _json_number(v) for k, v in metrics['peak_ratios_processed'].items()}
    }
}

with open(os.path.join(results_dir, 'comprehensive_results.json'), 'w') as f:
    json.dump(results, f, indent=2)

closing_rule = "="*70
print("\n" + closing_rule)
print("All visualizations saved to:", visualizations_dir)
print("Results saved to:", results_dir)
print(closing_rule)

# Fixed list of the figure files named in the final summary.
print("\nGenerated visualizations:")
generated_files = (
    'representative_spectra.png',
    'augmentation_effects.png',
    'preprocessing_comparison.png',
    'gadf_samples.png',
    'peak_preservation_snr.png',
    'training_curves.png',
    'occlusion_analysis.png',
    'confusion_densenet_1d.png',
    'confusion_densenet_2d.png',
    'confusion_resnet_1d.png',
    'confusion_resnet_2d.png',
)
print("\n".join(f"  - {viz_file}" for viz_file in generated_files))