### RBM learning


In [1]:
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from io import StringIO

# Module-level flag. NOTE(review): no code shown in this notebook reads it;
# confirm whether it is still needed.
USE_TORCH = True


class ComplexRBM(nn.Module):
    """
    PyTorch implementation of RBM for learning shell interactions in complex-valued data.

    Each complex shell amplitude is represented by two visible units, its real
    and imaginary part, interleaved as ``[Re_1, Im_1, Re_2, Im_2, ...]``. That
    is the layout ``shell_interaction_strength`` assumes.

    Training uses contrastive divergence (CD-k) with a hand-written momentum
    update; autograd and torch optimizers are not involved.

    By default the visible units are Bernoulli (sigmoid mean, binary samples).
    For real-valued inputs such as standardized velocities, pass
    ``gaussian_visible=True`` to use unit-variance Gaussian visible units
    (a Gaussian-Bernoulli RBM) instead.
    """

    def __init__(self, n_visible, n_hidden, k=1, learning_rate=0.01, momentum=0.5,
                 weight_decay=1e-4, gaussian_visible=False):
        """
        Initialize the RBM.

        Args:
            n_visible: Number of visible units (2 * number of shells for complex data)
            n_hidden: Number of hidden units
            k: Number of Gibbs sampling steps used by ``train_batch`` (typically 1)
            learning_rate: Learning rate for gradient ascent
            momentum: Momentum factor for the parameter updates
            weight_decay: L2 regularization factor (applied to W only)
            gaussian_visible: If True, visible units are unit-variance Gaussians
                instead of Bernoulli units. Defaults to False (original behavior).
        """
        super().__init__()

        # Model parameters (the only entries stored in state_dict()).
        self.W = nn.Parameter(torch.randn(n_visible, n_hidden) * 0.01)
        self.v_bias = nn.Parameter(torch.zeros(n_visible))
        self.h_bias = nn.Parameter(torch.zeros(n_hidden))

        # Hyperparameters
        self.k = k
        self.learning_rate = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.gaussian_visible = gaussian_visible

        # One reconstruction error per train_batch call.
        self.reconstruction_error_history = []

        # Momentum accumulators. Non-persistent buffers follow the module on
        # .to(device) / .double(), but stay out of state_dict(). This keeps
        # checkpoints loadable by code that only knows W / v_bias / h_bias.
        self.register_buffer("W_momentum", torch.zeros(n_visible, n_hidden), persistent=False)
        self.register_buffer("v_bias_momentum", torch.zeros(n_visible), persistent=False)
        self.register_buffer("h_bias_momentum", torch.zeros(n_hidden), persistent=False)

    def sample_h_given_v(self, v):
        """Return (P(h=1 | v), Bernoulli sample of h) for visible batch ``v``."""
        h_prob = torch.sigmoid(torch.matmul(v, self.W) + self.h_bias)
        h_sample = torch.bernoulli(h_prob)
        return h_prob, h_sample

    def sample_v_given_h(self, h):
        """
        Return (mean of v | h, sample of v | h) for hidden batch ``h``.

        Bernoulli visibles: mean is the sigmoid probability and the sample is binary.
        Gaussian visibles: mean is the linear activation and unit-variance noise is added.
        """
        activation = torch.matmul(h, self.W.t()) + self.v_bias
        if self.gaussian_visible:
            return activation, activation + torch.randn_like(activation)
        v_prob = torch.sigmoid(activation)
        v_sample = torch.bernoulli(v_prob)
        return v_prob, v_sample

    def free_energy(self, v):
        """
        Calculate the free energy F(v) for each row of ``v``.

        ``softplus`` is the numerically stable form of ``log(1 + exp(x))``. The
        naive expression overflows to inf for large pre-activations.
        """
        hidden_term = torch.sum(nn.functional.softplus(torch.matmul(v, self.W) + self.h_bias), dim=1)
        if self.gaussian_visible:
            # Unit-variance Gaussian visibles: quadratic term around the visible bias.
            return 0.5 * torch.sum((v - self.v_bias) ** 2, dim=1) - hidden_term
        v_bias_term = torch.matmul(v, self.v_bias)
        return -v_bias_term - hidden_term

    def forward(self, v):
        """Forward pass: one up-down Gibbs pass, returning the visible mean."""
        _, h_sample = self.sample_h_given_v(v)
        v_prob, _ = self.sample_v_given_h(h_sample)
        return v_prob

    @torch.no_grad()
    def contrastive_divergence(self, v_data, k=1):
        """
        Compute CD-k update directions for one batch.

        The updates are closed-form, so autograd is disabled. Otherwise a graph
        would be recorded through every Gibbs step for nothing.

        Returns:
            (W_grad, v_bias_grad, h_bias_grad, v_sample): ascent directions
            (W_grad already includes the L2 decay term) and the final visible
            sample of the Gibbs chain.
        """
        # Positive phase
        h_prob, _ = self.sample_h_given_v(v_data)
        pos_associations = torch.matmul(v_data.t(), h_prob)

        # Negative phase (k steps of Gibbs sampling starting at the data)
        v_sample = v_data.clone()
        for _ in range(k):
            _, h_sample = self.sample_h_given_v(v_sample)
            _, v_sample = self.sample_v_given_h(h_sample)

        h_prob_neg, _ = self.sample_h_given_v(v_sample)
        neg_associations = torch.matmul(v_sample.t(), h_prob_neg)

        batch_size = v_data.size(0)
        W_grad = (pos_associations - neg_associations) / batch_size
        v_bias_grad = torch.mean(v_data - v_sample, dim=0)
        h_bias_grad = torch.mean(h_prob - h_prob_neg, dim=0)

        # L2 weight decay
        W_grad -= self.weight_decay * self.W

        return W_grad, v_bias_grad, h_bias_grad, v_sample

    def train_batch(self, v_data, optimizer=None):
        """
        Apply one CD-k momentum update on ``v_data`` and return its reconstruction error.

        ``optimizer`` is accepted but not used.
        """
        W_grad, v_bias_grad, h_bias_grad, v_recon = self.contrastive_divergence(v_data, self.k)

        with torch.no_grad():
            self.W_momentum = self.momentum * self.W_momentum + self.learning_rate * W_grad
            self.v_bias_momentum = self.momentum * self.v_bias_momentum + self.learning_rate * v_bias_grad
            self.h_bias_momentum = self.momentum * self.h_bias_momentum + self.learning_rate * h_bias_grad

            self.W += self.W_momentum
            self.v_bias += self.v_bias_momentum
            self.h_bias += self.h_bias_momentum

        # Mean over the batch of the squared error between data and the chain's final sample
        reconstruction_error = torch.mean(torch.sum((v_data - v_recon) ** 2, dim=1))
        self.reconstruction_error_history.append(reconstruction_error.item())

        return reconstruction_error.item()

    @torch.no_grad()
    def generate_samples(self, n_samples, initial_v=None, k_steps=1000):
        """
        Generate samples by running ``k_steps`` of Gibbs sampling.

        Args:
            n_samples: Number of chains when ``initial_v`` is None
            initial_v: Optional starting visible states. Defaults to uniform
                noise on the model's device and dtype.
            k_steps: Number of Gibbs steps

        Returns:
            Visible samples after the last step (no autograd history attached).
        """
        if initial_v is None:
            initial_v = torch.rand(n_samples, self.v_bias.size(0),
                                   device=self.v_bias.device, dtype=self.v_bias.dtype)

        v_samples = initial_v.clone()
        for _ in range(k_steps):
            _, h_samples = self.sample_h_given_v(v_samples)
            _, v_samples = self.sample_v_given_h(h_samples)

        return v_samples

    def shell_interaction_strength(self):
        """
        Analyze the weight matrix to determine shell interaction strengths.

        Element (i, j) is the mean, over hidden units and over the four
        real/imag pairings, of |w_i * w_j|. Because |a*b| = |a|*|b|, this
        factorizes into (A @ A.T) / (4 * n_hidden). Here A[s] = |W[2s]| + |W[2s+1]|
        is the combined real+imag weight magnitude of shell s.
        """
        n_shells = self.W.shape[0] // 2  # real and imaginary unit per shell
        n_hidden = self.W.shape[1]
        abs_w = torch.abs(self.W[: 2 * n_shells])
        shell_weight = abs_w[0::2] + abs_w[1::2]  # shape (n_shells, n_hidden)
        return torch.matmul(shell_weight, shell_weight.t()) / (4 * n_hidden)


def load_complex_data(file_path):
    """
    Load complex spectral data with proper handling for scientific notation and minus signs.

    Each non-blank line must contain at least three numbers: wavenumber, real
    part, imaginary part. Any further numbers are ignored. Numbers may be
    written without separating whitespace (e.g. ``1.0E-05-2.3E-04``). They may
    also use a leading decimal point (``-.5``, ``.25e-3``). Lines with fewer
    than three numbers are reported and skipped.

    Args:
        file_path: Path to the data file

    Returns:
        wavenumbers: Array of wavenumbers (shells)
        velocities: Complex array of velocities

    Raises:
        ValueError: If no line could be parsed.
        OSError: If the file cannot be read.
    """
    # Optional sign, then either "digits[.digits]" or ".digits", then an optional
    # exponent. The ".digits" alternative fixes values like "-.5", which used to be
    # read as "5" (wrong sign and magnitude). Compiled once per call, not per line.
    number_re = re.compile(r'[+-]?(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')

    try:
        with open(file_path, 'r') as f:
            content = f.read().strip()

        data_lines = []
        for line in content.splitlines():
            if not line.strip():
                continue

            numbers = number_re.findall(line)
            if len(numbers) >= 3:
                data_lines.append([float(token) for token in numbers[:3]])
            else:
                print(f"Warning: Could not parse line: {line}")

        if not data_lines:
            raise ValueError("No valid data lines found")

        data = np.array(data_lines)
        wavenumbers = data[:, 0]
        velocities = data[:, 1] + 1j * data[:, 2]

        return wavenumbers, velocities

    except Exception as e:
        # Report which file failed, then propagate the original exception.
        print(f"Error loading file {file_path}: {e}")
        raise


def load_multiple_snapshots(directory, pattern=r"spectr_complex\..*\.txt", max_files=None):
    """
    Load multiple snapshots from a directory matching the given pattern.

    File names are sorted *naturally* (digit runs compared as integers), so
    ``spectr_complex.2.txt`` precedes ``spectr_complex.10.txt``. Files that fail
    to load, or whose shell count differs from the first file, are reported and
    left out of the result. They are no longer kept as all-zero rows.

    Args:
        directory: Directory containing snapshot files
        pattern: Regex matched against the start of each file name (``re.match``)
        max_files: Maximum number of files to load, taken from the end of the
            sorted list (None for all)

    Returns:
        wavenumbers: Array of wavenumbers (shells), read from the first file
        all_velocities: Complex array [n_loaded_snapshots, n_shells]

    Raises:
        ValueError: If no file name matches ``pattern``.
    """
    file_pattern = re.compile(pattern)

    def _natural_key(name):
        # re.split with a capturing group yields [text, digits, text, digits, ...],
        # so odd positions are always digit runs; compare those numerically.
        # The raw name breaks ties (e.g. "01" vs "1") deterministically.
        tokens = re.split(r'(\d+)', name)
        return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)], name

    all_files = sorted((f for f in os.listdir(directory) if file_pattern.match(f)), key=_natural_key)

    if max_files is not None:
        # Keep the last max_files files. Unlike all_files[-max_files:], max_files=0 keeps none.
        all_files = all_files[max(len(all_files) - max_files, 0):]

    print(f"Found {len(all_files)} snapshot files.")

    if not all_files:
        raise ValueError(f"No files found matching pattern {pattern} in directory {directory}")

    # The first file defines the wavenumbers and the expected shell count.
    wavenumbers, _ = load_complex_data(os.path.join(directory, all_files[0]))
    n_shells = len(wavenumbers)

    rows = []
    for file_name in tqdm(all_files, desc="Loading files"):
        try:
            _, velocities = load_complex_data(os.path.join(directory, file_name))
        except Exception as e:
            # Best effort: a single unreadable snapshot should not abort the whole load.
            print(f"Error loading file {file_name}: {e}")
            continue
        if len(velocities) != n_shells:
            print(f"Warning: File {file_name} has {len(velocities)} shells, expected {n_shells}")
            continue
        rows.append(velocities)

    n_skipped = len(all_files) - len(rows)
    if n_skipped:
        print(f"Skipped {n_skipped} of {len(all_files)} files; they are excluded from the data.")

    all_velocities = np.array(rows, dtype=complex).reshape(len(rows), n_shells)
    return wavenumbers, all_velocities


def prepare_data_for_rbm(velocities, normalize=True):
    """
    Flatten complex velocities into a real-valued float32 tensor for the RBM.

    Args:
        velocities: Complex array of shape [n_samples, n_shells]
        normalize: If True, standardize the real and the imaginary parts
            column-wise with two independent StandardScalers.

    Returns:
        data: float32 tensor of shape [n_samples, 2*n_shells], laid out as
            [real_1, imag_1, real_2, imag_2, ...]
        scalers: (scaler_real, scaler_imag) when normalize is True, else None
    """
    re_part = np.real(velocities)
    im_part = np.imag(velocities)

    if not normalize:
        scalers = None
    else:
        scaler_real = StandardScaler()
        scaler_imag = StandardScaler()
        re_part = scaler_real.fit_transform(re_part)
        im_part = scaler_imag.fit_transform(im_part)
        scalers = (scaler_real, scaler_imag)

    n_rows, n_cols = velocities.shape[0], velocities.shape[1]
    # Pairing the two parts on a new trailing axis and flattening it yields the
    # interleaved column order re_1, im_1, re_2, im_2, ...
    interleaved = np.stack((re_part, im_part), axis=-1).reshape(n_rows, 2 * n_cols)

    return torch.tensor(interleaved, dtype=torch.float32), scalers


def train_rbm(rbm, train_data, batch_size=100, num_epochs=100, early_stopping=False,
              patience=10, min_improvement=0.001, k=1, normalize=True,
              save_dir="/home/vale/SABRA/params_bin/RBM_figures"):
    """
    Train the RBM on the provided data and save a plot of the training curve.

    Args:
        rbm: The RBM model. It must provide ``train_batch``, ``k``,
            ``weight_decay``, ``v_bias`` and ``h_bias``.
        train_data: Training data tensor, one sample per row
        batch_size: Batch size for training
        num_epochs: Maximum number of training epochs
        early_stopping: Whether to use early stopping
        patience: Number of epochs with no improvement before stopping
        min_improvement: Minimum relative improvement to reset patience counter
        k: Not used; the number of Gibbs steps is taken from ``rbm.k``. Kept
            so existing calls that pass it keep working.
        normalize: Only tags the figure filename ("normalized" /
            "not_normalized"); it does not affect training.
        save_dir: Directory where the training-progress figure is written
            (created if missing).

    Returns:
        rbm: The trained RBM model (trained in place; the same object is returned)

    Raises:
        ValueError: If ``train_data`` contains no samples.
    """
    if len(train_data) == 0:
        # Without this, the per-epoch average below would divide by zero batches.
        raise ValueError("train_data is empty; cannot train the RBM")

    # Shuffled mini-batches over the training rows
    train_loader = DataLoader(
        TensorDataset(train_data),
        batch_size=batch_size,
        shuffle=True
    )

    best_error = float('inf')
    epochs_no_improve = 0
    epoch_errors = []

    for epoch in range(num_epochs):
        epoch_error = 0.0
        num_batches = 0

        for (data,) in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            epoch_error += rbm.train_batch(data)
            num_batches += 1

        avg_epoch_error = epoch_error / num_batches
        epoch_errors.append(avg_epoch_error)

        if early_stopping:
            # An epoch counts as improvement only if it beats the best error by
            # a relative margin of min_improvement.
            if best_error * (1 - min_improvement) > avg_epoch_error:
                best_error = avg_epoch_error
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve >= patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

    # Plot training error
    fig = plt.figure(figsize=(10, 5))
    plt.plot(epoch_errors)
    plt.title('RBM Training Progress')
    plt.xlabel('Epoch')
    plt.ylabel('Reconstruction Error')
    plt.grid(True)
    normalization_status = "normalized" if normalize else "not_normalized"
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f"rbm_training_progress_hidden{rbm.h_bias.shape[0]}_visible{rbm.v_bias.shape[0]}_wd{rbm.weight_decay}_batch{batch_size}_k{rbm.k}_{normalization_status}.png"))
    plt.close(fig)

    return rbm


def visualize_shell_interactions(rbm, wavenumbers, batch_size=100, k=1,
                                 save_dir="/home/vale/SABRA/params_bin/RBM_figures"):
    """
    Visualize the strength of interactions between shells.

    Saves two figures into ``save_dir``: the full interaction heat map, and
    the diagonal plus +1/+2 off-diagonal profiles. Also writes
    ``shell_interaction_matrix.npy`` to the current working directory.

    Args:
        rbm: Trained RBM model (must provide ``shell_interaction_strength``)
        wavenumbers: Array of wavenumbers corresponding to each shell
        batch_size: Only used to tag the heat-map filename
        k: Not used; the filename uses ``rbm.k``. Kept for existing callers.
        save_dir: Directory for the figures (created if missing)
    """
    interaction_matrix = rbm.shell_interaction_strength().detach().cpu().numpy()
    wavenumbers = np.asarray(wavenumbers)
    n_shells = len(wavenumbers)
    log_k = np.log10(wavenumbers)
    os.makedirs(save_dir, exist_ok=True)

    # Full matrix heat map, with wavenumber tick labels in scientific notation
    shell_labels = [f"{wn:.1e}" for wn in wavenumbers]

    fig = plt.figure(figsize=(10, 8))
    plt.imshow(interaction_matrix, cmap='viridis', interpolation='nearest')
    plt.colorbar(label='Interaction Strength')
    plt.title('Shell Interaction Strength Matrix')
    plt.xlabel('Shell Wavenumber')
    plt.ylabel('Shell Wavenumber')
    plt.xticks(np.arange(n_shells), shell_labels, rotation=45)
    plt.yticks(np.arange(n_shells), shell_labels)
    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, f"shell_interaction_matrix_hidden{rbm.h_bias.shape[0]}_visible{rbm.v_bias.shape[0]}_batch{batch_size}_wd{rbm.weight_decay}_k{rbm.k}.png"))
    plt.close(fig)

    # Self-interaction (diagonal) and nearest-neighbor interactions (off-diagonals)
    fig = plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(log_k, np.diag(interaction_matrix), 'o-')
    plt.title('Self-Interaction Strength')
    plt.xlabel('log10(Wavenumber)')
    plt.ylabel('Interaction Strength')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    # Element (i, i+1) of the matrix, i.e. its first super-diagonal
    plt.plot(log_k[:-1], np.diagonal(interaction_matrix, offset=1), 'o-', label='+1 Neighbor')
    # +2 neighbor (relevant for triadic interactions) when at least 3 shells exist
    if n_shells > 2:
        plt.plot(log_k[:-2], np.diagonal(interaction_matrix, offset=2), 'o-', label='+2 Neighbor')

    plt.title('Neighbor Interaction Strength')
    plt.xlabel('log10(Wavenumber)')
    plt.ylabel('Interaction Strength')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    # NOTE: reconstruction_error_history gets one entry per training batch, so
    # the "epochs" number in this filename is really a batch count. It is kept
    # unchanged so existing output paths stay stable.
    fig.savefig(os.path.join(save_dir, f"shell_neighbor_interactions_hidden{rbm.h_bias.shape[0]}_epochs{len(rbm.reconstruction_error_history)}.png"))
    plt.close(fig)

    # Raw matrix for further analysis (written to the current working directory)
    np.save("shell_interaction_matrix.npy", interaction_matrix)


def analyze_generated_samples(rbm, real_data, wavenumbers, n_samples=1000, batch_size=100, k=1,
                              save_dir="/home/vale/SABRA/params_bin/RBM_figures"):
    """
    Generate samples from the RBM and compare their statistics with real data.

    Chains are started from the first ``n_samples`` rows of ``real_data``. Two
    figures are saved into ``save_dir``:

    - the mean |u|^2 per shell, as an energy spectrum;
    - real/imaginary histograms for the first, middle and last shell.

    NOTE: if the RBM's visible units are binary (sigmoid + Bernoulli sampling),
    generated values lie in {0, 1}. They are then not on the same scale as
    standardized ``real_data``.

    Args:
        rbm: Trained RBM model (must provide ``generate_samples``)
        real_data: Real data tensor used for training, columns interleaved as
            [Re_1, Im_1, Re_2, Im_2, ...]
        wavenumbers: Array of wavenumbers
        n_samples: Number of samples to generate (capped at len(real_data))
        batch_size: Only used to tag the figure filenames
        k: Not used; filenames use ``rbm.k``. Kept for existing callers.
        save_dir: Directory for the figures (created if missing)

    Raises:
        ValueError: If ``wavenumbers`` is empty.
    """
    n_shells = len(wavenumbers)
    if n_shells == 0:
        raise ValueError("wavenumbers must contain at least one shell")
    n_cols = 2 * n_shells

    n_samples = min(n_samples, real_data.shape[0])  # Don't exceed available data
    generated_samples = rbm.generate_samples(n_samples, initial_v=real_data[:n_samples]).detach().cpu().numpy()
    real_data_np = real_data.detach().cpu().numpy()

    def _to_complex(flat):
        # Interleaved [Re_1, Im_1, ...] columns -> complex128 array [n_rows, n_shells]
        flat = np.asarray(flat, dtype=np.float64)
        return flat[:, 0:n_cols:2] + 1j * flat[:, 1:n_cols:2]

    real_complex = _to_complex(real_data_np)
    gen_complex = _to_complex(generated_samples)

    # Energy spectrum: mean |u|^2 per shell
    real_energy = np.mean(np.abs(real_complex) ** 2, axis=0)
    gen_energy = np.mean(np.abs(gen_complex) ** 2, axis=0)

    os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure(figsize=(10, 6))
    plt.loglog(wavenumbers, real_energy, 'o-', label='Real Data')
    plt.loglog(wavenumbers, gen_energy, 's-', label='Generated Data')
    plt.title('Energy Spectrum Comparison')
    plt.xlabel('Wavenumber')
    plt.ylabel('Energy')
    plt.legend()
    plt.grid(True)
    fig.savefig(os.path.join(save_dir, f"energy_spectrum_comparison_hidden{rbm.h_bias.shape[0]}_visible{rbm.v_bias.shape[0]}_wd{rbm.weight_decay}_batch{batch_size}_k{rbm.k}.png"))
    plt.close(fig)

    # First, middle and last shell; de-duplicated so short shell lists do not repeat rows
    shells_to_plot = list(dict.fromkeys([0, n_shells // 2, n_shells - 1]))

    # squeeze=False keeps `axes` 2-D even when only one row is plotted
    fig, axes = plt.subplots(len(shells_to_plot), 2, figsize=(12, 4 * len(shells_to_plot)), squeeze=False)

    for row_axes, shell_idx in zip(axes, shells_to_plot):
        for ax, offset, part_name in ((row_axes[0], 0, 'Real Part'), (row_axes[1], 1, 'Imaginary Part')):
            col = 2 * shell_idx + offset
            ax.hist(real_data_np[:, col], bins=30, alpha=0.5, density=True, label='Real Data')
            ax.hist(generated_samples[:, col], bins=30, alpha=0.5, density=True, label='Generated')
            ax.set_title(f'Shell k={wavenumbers[shell_idx]:.1e} - {part_name}')
            ax.legend()
            ax.grid(True)

    plt.tight_layout()
    fig.savefig(os.path.join(save_dir, f"velocity_distribution_comparison_hidden{rbm.h_bias.shape[0]}_visible{rbm.v_bias.shape[0]}_wd{rbm.weight_decay}_batch{batch_size}_k{rbm.k}.png"))
    plt.close(fig)



def main():
    """
    Main function to run the RBM analysis on SABRA model data.

    The pipeline is: load snapshots, build standardized RBM inputs, train, and
    save the model under ``RBMs/``. It then saves the interaction and sample
    figures. Finally it writes the hidden-unit probabilities to
    ``feature_representations.npy``.
    """
    # Configuration
    data_dir = "/home/vale/SABRA/params_bin/sim_nn10_LAM_2_nu1e8/Spectra_complex"
    max_files = 10000
    batch_size = 100
    num_epochs = 200
    n_hidden = 10
    normalize = True
    weight_decay = 1e-4
    k = 3

    # Load multiple snapshots. The raw string avoids the invalid "\." escape
    # warning; the regex itself is unchanged.
    print(f"Loading data from directory: {data_dir}")
    try:
        wavenumbers, velocities = load_multiple_snapshots(data_dir, pattern=r"spectr_complex.*\.txt", max_files=max_files)
    except Exception as e:
        print(f"Error loading data: {e}")
        return

    print(f"Loaded {velocities.shape[0]} snapshots with {velocities.shape[1]} shells each.")

    # Prepare data for RBM
    data, scalers = prepare_data_for_rbm(velocities, normalize=normalize)
    print(f"scalers: {scalers}")

    # Two visible units (real, imaginary) per shell
    n_visible = 2 * len(wavenumbers)
    print(f"Initializing RBM with {n_visible} visible units and {n_hidden} hidden units")

    rbm = ComplexRBM(n_visible=n_visible,
                     n_hidden=n_hidden,
                     k=k,
                     learning_rate=0.01,
                     momentum=0.5,
                     weight_decay=weight_decay)

    print("Training RBM...")

    # To resume from a checkpoint instead of starting fresh:
    # rbm.load_state_dict(torch.load("trained_rbm.pt"))

    rbm = train_rbm(rbm, data, batch_size=batch_size, num_epochs=num_epochs, k=rbm.k, normalize=normalize)

    # Save trained model
    save_path = "RBMs"
    os.makedirs(save_path, exist_ok=True)
    model_filename = f"trained_rbm_hidden{rbm.h_bias.shape[0]}_visible{rbm.v_bias.shape[0]}_wd{rbm.weight_decay}_batch{batch_size}_k{rbm.k}_epochs{num_epochs}.pt"
    torch.save(rbm.state_dict(), os.path.join(save_path, model_filename))
    print(f"RBM training complete. Saved model to '{os.path.join(save_path, model_filename)}'")

    print("Visualizing shell interactions...")
    visualize_shell_interactions(rbm, wavenumbers, batch_size=batch_size, k=rbm.k)

    print("Generating and analyzing samples...")
    analyze_generated_samples(rbm, data, wavenumbers, n_samples=100, batch_size=batch_size, k=rbm.k)

    # Hidden-unit activation probabilities P(h=1 | v) for every training sample
    with torch.no_grad():
        h_probs = torch.sigmoid(torch.matmul(data, rbm.W) + rbm.h_bias)

    np.save("feature_representations.npy", h_probs.numpy())

    print("Analysis complete. Results saved to disk.")


# Entry point: runs the full pipeline. Inside a notebook kernel __name__ is also
# "__main__", so executing this cell starts training (as the output below shows).
if __name__ == "__main__":
    main()

Loading data from directory: /home/vale/SABRA/params_bin/sim_nn10_LAM_2_nu1e8/Spectra_complex
Found 10000 snapshot files.


Loading files: 100%|██████████| 10000/10000 [00:09<00:00, 1047.48it/s]


Loaded 10000 snapshots with 10 shells each.
scalers: (StandardScaler(), StandardScaler())
Initializing RBM with 20 visible units and 10 hidden units
Training RBM...


Epoch 1/200: 100%|██████████| 100/100 [00:00<00:00, 387.47it/s]
Epoch 2/200: 100%|██████████| 100/100 [00:00<00:00, 556.29it/s]
Epoch 3/200: 100%|██████████| 100/100 [00:00<00:00, 476.00it/s]
Epoch 4/200: 100%|██████████| 100/100 [00:00<00:00, 486.83it/s]
Epoch 5/200: 100%|██████████| 100/100 [00:00<00:00, 391.19it/s]
Epoch 6/200: 100%|██████████| 100/100 [00:00<00:00, 518.62it/s]
Epoch 7/200: 100%|██████████| 100/100 [00:00<00:00, 574.71it/s]
Epoch 8/200: 100%|██████████| 100/100 [00:00<00:00, 503.97it/s]
Epoch 9/200: 100%|██████████| 100/100 [00:00<00:00, 536.83it/s]
Epoch 10/200: 100%|██████████| 100/100 [00:00<00:00, 548.57it/s]
Epoch 11/200: 100%|██████████| 100/100 [00:00<00:00, 431.09it/s]
Epoch 12/200: 100%|██████████| 100/100 [00:00<00:00, 433.26it/s]
Epoch 13/200: 100%|██████████| 100/100 [00:00<00:00, 566.49it/s]
Epoch 14/200: 100%|██████████| 100/100 [00:00<00:00, 520.58it/s]
Epoch 15/200: 100%|██████████| 100/100 [00:00<00:00, 486.64it/s]
Epoch 16/200: 100%|██████████| 100

KeyboardInterrupt: 

In [None]:
data_dir = "/home/vale/SABRA/params_bin/sim_nn20_LAM_2_nu1e8/Spectra_complex"
max_files = 10000
batch_size = 100
num_epochs = 100
# Raw string: same regex, without the invalid "\." escape warning.
wavenumbers, velocities = load_multiple_snapshots(data_dir, pattern=r"spectr_complex.*\.txt", max_files=max_files)
data, scalers = prepare_data_for_rbm(velocities, normalize=True)

# Recreate the model from the checkpoint's own weight shape. This way
# n_visible / n_hidden cannot disagree with the saved parameters.
# NOTE(review): main() saves to "RBMs/trained_rbm_hidden..._epochs....pt";
# point checkpoint_path at the file you actually want to analyze.
checkpoint_path = "trained_rbm.pt"
state_dict = torch.load(checkpoint_path, map_location="cpu")
n_visible, n_hidden = state_dict["W"].shape
if n_visible != data.shape[1]:
    raise ValueError(
        f"Checkpoint expects {n_visible} visible units but the data has {data.shape[1]} columns "
        f"({len(wavenumbers)} shells)"
    )

# NOTE: k / weight_decay are not stored in the state dict, so figure filenames
# below use the constructor defaults rather than the training values.
rbm = ComplexRBM(n_visible=n_visible, n_hidden=n_hidden)
rbm.load_state_dict(state_dict)

# Now call the analysis function
analyze_generated_samples(rbm, data, wavenumbers, n_samples=1000)