In [None]:
 # Persist the current training state under checkpoints_path as 'checkpoint_epoch_<epoch+1>.pth.tar'.
 # NOTE(review): save_checkpoint, checkpoint_state, checkpoints_path, epoch and os all come from
 # earlier cells; the +1 presumably turns a 0-based loop index into a 1-based file name — confirm.
 save_checkpoint(checkpoint_state, filename=os.path.join(checkpoints_path, f'checkpoint_epoch_{epoch+1}.pth.tar'))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

def calculate_errors(true_spectra, generated_spectra):
    """
    Compute the element-wise residuals ``true_spectra - generated_spectra``.

    Parameters:
    true_spectra (numpy array): Reference (ground-truth) spectra.
    generated_spectra (numpy array): Spectra produced by the model; standard
        NumPy broadcasting rules apply between the two operands.

    Returns:
    numpy array: Residuals, positive wherever the generated spectrum lies
    below the true one.
    """
    residuals = true_spectra - generated_spectra
    return residuals

def plot_error_distribution(errors, wavelengths):
    """
    Plot one histogram (with KDE overlay) of the errors at each wavelength.

    All per-wavelength histograms are overlaid on a single, freshly created
    figure, so leftover axes from an earlier plot are never reused.

    Parameters:
    errors (numpy array): Errors between true and generated spectra, shape
        (num_samples, num_wavelengths); column ``i`` holds the errors at
        ``wavelengths[i]``.
    wavelengths (numpy array): Wavelengths corresponding to the columns of
        ``errors``.

    Raises:
    ValueError: If ``errors`` is not 2-D, or its number of columns differs
        from ``len(wavelengths)``.
    """
    errors = np.asarray(errors)
    if errors.ndim != 2:
        raise ValueError(
            f'errors must be 2-D (num_samples, num_wavelengths), got shape {errors.shape}'
        )
    if errors.shape[1] != len(wavelengths):
        raise ValueError(
            f'errors has {errors.shape[1]} columns but {len(wavelengths)} wavelengths were given'
        )

    # Start a new figure instead of drawing onto whatever axes are current.
    plt.figure(figsize=(15, 10))
    # errors.T yields one column (all samples at a single wavelength) per iteration.
    for wavelength, column in zip(wavelengths, errors.T):
        sns.histplot(column, kde=True, label=f'Wavelength {wavelength}')
    plt.xlabel('Error')
    plt.ylabel('Frequency')
    plt.title('Error Distribution Across Wavelengths')
    # There is one legend entry per wavelength, so place the legend outside the
    # axes where it cannot hide the histograms; tight_layout makes room for it.
    plt.legend(loc='upper left', bbox_to_anchor=(1.01, 1), fontsize='small', ncol=2)
    plt.tight_layout()
    plt.show()

def plot_generated_spectra_over_epochs(generated_spectra_over_epochs, epochs, wavelengths):
    """
    Overlay the generated spectra recorded at different epochs on one figure.

    Parameters:
    generated_spectra_over_epochs (list of numpy arrays): Generated spectra,
        one per recorded epoch; each is plotted against ``wavelengths``.
    epochs (list of int): Epoch numbers aligned one-to-one with
        ``generated_spectra_over_epochs``; used as the legend labels.
    wavelengths (numpy array): Wavelengths used as the shared x-axis.

    Raises:
    ValueError: If the number of spectra and the number of epochs differ.
    """
    # Materialize once so any iterable (not only lists) can be length-checked.
    spectra = list(generated_spectra_over_epochs)
    # Validate before opening a figure, so a mismatch never leaves a
    # half-drawn plot behind.
    if len(spectra) != len(epochs):
        raise ValueError(
            f'Got {len(spectra)} spectra but {len(epochs)} epochs; '
            'they must be aligned one-to-one.'
        )

    plt.figure(figsize=(15, 10))
    for epoch, spectrum in zip(epochs, spectra):
        plt.plot(wavelengths, spectrum, label=f'Epoch {epoch}')
    plt.xlabel('Wavelength')
    plt.ylabel('Intensity')
    plt.title('Generated Spectra Over Epochs')
    plt.legend()
    plt.show()

# Example usage with synthetic data.
# NOTE(review): this cell rebinds notebook-level names such as `epochs`, `errors`
# and `wavelengths`; confirm no later training/evaluation cell relies on the
# earlier values of these names.

# Shared sizes keep every synthetic array below shape-consistent.
num_samples = 100
num_wavelengths = 50
num_epochs_recorded = 10
epoch_interval = 10
# A seeded, local generator makes the demo reproducible without touching
# NumPy's global RNG state, which the training code may depend on.
rng = np.random.default_rng(0)

# true_spectra and generated_spectra have shape (num_samples, num_wavelengths)
true_spectra = rng.random((num_samples, num_wavelengths))  # Example true spectra
generated_spectra = rng.random((num_samples, num_wavelengths))  # Example generated spectra

# Calculate errors
errors = calculate_errors(true_spectra, generated_spectra)

# Plot error distribution
wavelengths = np.linspace(400, 700, num_wavelengths)  # Example wavelengths
plot_error_distribution(errors, wavelengths)

# Plot generated spectra over epochs: one spectrum every `epoch_interval` epochs
generated_spectra_over_epochs = [rng.random(num_wavelengths) for _ in range(num_epochs_recorded)]
epochs = list(range(0, num_epochs_recorded * epoch_interval, epoch_interval))
plot_generated_spectra_over_epochs(generated_spectra_over_epochs, epochs, wavelengths)