# VAE-LSTM Model Training for Predictive Maintenance

This notebook implements a Variational Autoencoder (VAE) combined with an LSTM encoder/decoder for anomaly detection in predictive maintenance.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
# (seeds the global PyTorch and NumPy random number generators)
torch.manual_seed(42)
np.random.seed(42)

# Set style for plots
# NOTE(review): the 'seaborn-v0_8' style name presumably needs a recent
# matplotlib release -- confirm against the pinned environment.
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 1. VAE-LSTM Model Architecture

In [None]:
class VAE_LSTM(nn.Module):
    """Variational autoencoder with LSTM encoder and decoder.

    The encoder LSTM summarises a window into its final hidden state, which
    is projected to the mean and log-variance of a diagonal Gaussian. A latent
    sample is then repeated along the time axis and decoded by a second LSTM
    back to the input feature dimension.

    Inputs are expected as ``(batch, sequence, input_dim)`` tensors (both
    LSTMs are created with ``batch_first=True``).

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Hidden size of the encoder and decoder LSTMs.
        latent_dim: Size of the latent vector.
        sequence_length: Default number of time steps produced by ``decode``.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim, sequence_length):
        super().__init__()

        self.sequence_length = sequence_length
        self.latent_dim = latent_dim

        # Encoder LSTM
        self.encoder_lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

        # Latent space: two heads on top of the encoder's final hidden state
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim, hidden_dim)
        self.decoder_lstm = nn.LSTM(hidden_dim, hidden_dim, batch_first=True)
        self.decoder_output = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` of the approximate posterior for ``x``."""
        _, (h_n, _) = self.encoder_lstm(x)
        # h_n has a leading layer axis (not a batch axis); keep the last
        # layer's final hidden state. For this single-layer encoder that is
        # the same tensor the previous squeeze(0) produced.
        h_last = h_n[-1]
        mu = self.fc_mu(h_last)
        log_var = self.fc_var(h_last)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I)."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, sequence_length=None):
        """Decode latent vectors ``z`` of shape ``(batch, latent_dim)``.

        Args:
            z: Latent vectors, one row per sample.
            sequence_length: Number of time steps to generate. Defaults to the
                ``sequence_length`` given at construction time.

        Returns:
            Tensor of shape ``(batch, steps, input_dim)``.
        """
        steps = self.sequence_length if sequence_length is None else sequence_length

        # Repeat latent vector for every time step of the output window
        z = z.unsqueeze(1).repeat(1, steps, 1)

        # Decode
        h = self.decoder_input(z)
        output, _ = self.decoder_lstm(h)
        reconstruction = self.decoder_output(output)
        return reconstruction

    def forward(self, x):
        """Encode ``x``, sample a latent vector and reconstruct the window.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        # Decode to the length of the incoming window so the reconstruction
        # always lines up with ``x``, even when it differs from the default.
        reconstruction = self.decode(z, x.size(1))
        return reconstruction, mu, log_var

    def loss_function(self, reconstruction, x, mu, log_var):
        """Return summed squared reconstruction error plus KL divergence.

        Both terms are summed over the whole batch (no averaging), so the
        value scales with batch size.
        """
        # Reconstruction loss
        recon_loss = nn.functional.mse_loss(reconstruction, x, reduction='sum')

        # KL divergence between N(mu, exp(log_var)) and the standard normal
        kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())

        return recon_loss + kl_loss

## 2. Dataset Class

In [None]:
class TimeSeriesDataset(Dataset):
    """Sliding-window view over a time series.

    Item ``i`` is the slice ``data[i:i + sequence_length]``, so consecutive
    items overlap by ``sequence_length - 1`` rows.

    Args:
        data: Array-like of samples ordered along the first axis (anything
            ``numpy.asarray`` can convert, or a ``torch.Tensor``). Values are
            stored as a float32 tensor.
        sequence_length: Number of consecutive rows per window; must be >= 1.

    Raises:
        ValueError: If ``sequence_length`` is smaller than 1.
    """

    def __init__(self, data, sequence_length):
        if sequence_length < 1:
            raise ValueError(
                f"sequence_length must be >= 1, got {sequence_length}"
            )
        if isinstance(data, torch.Tensor):
            self.data = data.detach().to(dtype=torch.float32, device='cpu')
        else:
            # Go through NumPy so lists, arrays and DataFrame-like inputs are
            # all converted the same way.
            self.data = torch.tensor(np.asarray(data, dtype=np.float32))
        self.sequence_length = sequence_length

    def __len__(self):
        # Clamp at zero: a series shorter than one window has no items
        # (a negative value would make len() itself raise).
        return max(0, len(self.data) - self.sequence_length + 1)

    def __getitem__(self, idx):
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            # Raising here keeps every returned window full-length and lets
            # plain iteration over the dataset terminate.
            raise IndexError(f"window index {idx} out of range for {count} windows")
        return self.data[idx:idx + self.sequence_length]

## 3. Training Functions

In [None]:
def train_vae_lstm(model, train_loader, optimizer, device, num_epochs=100):
    """Optimise ``model`` on ``train_loader`` for ``num_epochs`` epochs.

    Each batch is moved to ``device``, passed through the model, scored with
    ``model.loss_function`` and used for one optimizer step. Progress is
    printed every tenth epoch.

    Returns:
        List with one entry per epoch: the sum of the batch losses divided by
        the number of batches in ``train_loader``.
    """
    model.train()

    def _optimise_on(sequences):
        # One gradient step on a single batch; returns the scalar loss value.
        sequences = sequences.to(device)
        optimizer.zero_grad()
        recon, mean, log_variance = model(sequences)
        batch_loss = model.loss_function(recon, sequences, mean, log_variance)
        batch_loss.backward()
        optimizer.step()
        return batch_loss.item()

    history = []
    for epoch_no in range(1, num_epochs + 1):
        running_total = 0
        for sequences in train_loader:
            running_total += _optimise_on(sequences)

        mean_loss = running_total / len(train_loader)
        history.append(mean_loss)

        if epoch_no % 10 == 0:
            print(f'Epoch {epoch_no}/{num_epochs}, Loss: {mean_loss:.4f}')

    return history

def compute_reconstruction_error(model, data_loader, device):
    """Score every window in ``data_loader`` by its mean squared error.

    The model is switched to eval mode and run without gradient tracking.
    For each window the squared difference between reconstruction and input
    is averaged over dimensions 1 and 2, giving one value per window.

    Returns:
        NumPy array of per-window errors in the order the loader yields them.
    """
    model.eval()
    per_window = []

    with torch.no_grad():
        for windows in data_loader:
            windows = windows.to(device)
            rebuilt, _mu, _log_var = model(windows)
            squared_diff = (rebuilt - windows).pow(2)
            window_mse = squared_diff.mean(dim=[1, 2])
            per_window.extend(window_mse.cpu().numpy())

    return np.array(per_window)

## 4. Anomaly Detection Functions

In [None]:
def set_anomaly_threshold(errors, contamination=0.1):
    """Return the error value at the ``(1 - contamination) * 100`` percentile.

    With the default ``contamination`` of 0.1 this is the 90th percentile of
    ``errors``, i.e. roughly the top 10% of the given errors lie above it.
    """
    percentile_rank = (1 - contamination) * 100
    return np.percentile(errors, percentile_rank)

def detect_anomalies(errors, threshold):
    """Flag every error that is strictly greater than ``threshold``.

    Args:
        errors: Per-window reconstruction errors. Plain lists and tuples are
            converted to a NumPy array first, since they do not support
            element-wise comparison; any other input is compared as given.
        threshold: Scalar cut-off; values equal to it are not flagged.

    Returns:
        Boolean mask with one entry per error (True marks an anomaly).
    """
    if isinstance(errors, (list, tuple)):
        errors = np.asarray(errors)
    return errors > threshold

def evaluate_anomaly_detection(true_labels, predicted_labels, scores=None):
    """Print a classification report, confusion matrix and ROC AUC.

    Args:
        true_labels: Ground-truth labels, one per sample.
        predicted_labels: Predicted labels, one per sample.
        scores: Optional continuous anomaly scores (for example the raw
            reconstruction errors). When given they are used for the ROC AUC
            instead of ``predicted_labels``; an AUC computed from hard labels
            only reflects the single operating point of the chosen threshold.

    Returns:
        The confusion matrix computed from ``true_labels`` and
        ``predicted_labels``.
    """
    from sklearn.metrics import classification_report, confusion_matrix

    print("Classification Report:")
    print(classification_report(true_labels, predicted_labels))

    cm = confusion_matrix(true_labels, predicted_labels)
    print("\nConfusion Matrix:")
    print(cm)

    # AUC is only computed when exactly two classes occur in the ground truth
    if len(np.unique(true_labels)) == 2:
        ranking = predicted_labels if scores is None else scores
        auc_score = roc_auc_score(true_labels, ranking)
        print(f"\nAUC Score: {auc_score:.4f}")

    return cm

## 5. Visualization Functions

In [None]:
def plot_training_loss(train_losses):
    """Show a line plot of the per-epoch training loss values."""
    _, axis = plt.subplots(figsize=(10, 6))
    axis.plot(train_losses, label='Training Loss')
    axis.set_xlabel('Epoch')
    axis.set_ylabel('Loss')
    axis.set_title('VAE-LSTM Training Loss')
    axis.legend()
    axis.grid(True)
    plt.show()

def plot_reconstruction_errors(errors, threshold=None, anomalies=None):
    """Plot the distribution and the sequence of reconstruction errors.

    The left panel is a 50-bin histogram of ``errors``; the right panel plots
    them in the order given.

    Args:
        errors: One reconstruction error per window.
        threshold: Optional cut-off drawn as a dashed red line in both panels.
            Any value other than ``None`` is drawn, including 0.
        anomalies: Optional boolean mask (or array whose non-zero entries mark
            anomalies) aligned with ``errors``; flagged points are highlighted
            in the right panel.
    """
    # Fancy indexing below needs an array, not a plain list
    errors = np.asarray(errors)
    has_threshold = threshold is not None

    plt.figure(figsize=(15, 6))

    plt.subplot(1, 2, 1)
    plt.hist(errors, bins=50, alpha=0.7, color='blue')
    if has_threshold:
        plt.axvline(x=threshold, color='red', linestyle='--', label=f'Threshold: {threshold:.4f}')
    plt.xlabel('Reconstruction Error')
    plt.ylabel('Frequency')
    plt.title('Distribution of Reconstruction Errors')
    if has_threshold:
        # Only the threshold line carries a label in this panel, so a legend
        # without it would be empty.
        plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(errors, alpha=0.7, color='blue', label='Reconstruction Error')
    if has_threshold:
        plt.axhline(y=threshold, color='red', linestyle='--', label=f'Threshold: {threshold:.4f}')
    if anomalies is not None:
        anomaly_indices = np.where(anomalies)[0]
        plt.scatter(anomaly_indices, errors[anomaly_indices], color='red', s=10, label='Anomalies')
    plt.xlabel('Time Step')
    plt.ylabel('Reconstruction Error')
    plt.title('Reconstruction Errors Over Time')
    plt.legend()

    plt.tight_layout()
    plt.show()

def plot_anomaly_comparison(original, reconstructed, anomaly_indices, feature_idx=0):
    """Compare one feature of the original and reconstructed series.

    The top panel overlays column ``feature_idx`` of ``original`` and
    ``reconstructed`` and marks the rows listed in ``anomaly_indices``. The
    bottom panel shows the per-row mean squared difference across columns
    with the same rows marked.
    """
    fig, (signal_ax, error_ax) = plt.subplots(2, 1, figsize=(15, 8))

    # Top panel: original vs reconstructed signal for the chosen feature
    signal_ax.plot(original[:, feature_idx], label='Original', alpha=0.7)
    signal_ax.plot(reconstructed[:, feature_idx], label='Reconstructed', alpha=0.7)
    flagged_values = original[anomaly_indices, feature_idx]
    signal_ax.scatter(anomaly_indices, flagged_values, color='red', s=20, label='Anomalies')
    signal_ax.set_xlabel('Time Step')
    signal_ax.set_ylabel(f'Feature {feature_idx}')
    signal_ax.set_title('Original vs Reconstructed Data')
    signal_ax.legend()

    # Bottom panel: row-wise reconstruction error
    row_mse = np.mean((original - reconstructed) ** 2, axis=1)
    error_ax.plot(row_mse, label='Reconstruction Error', color='orange')
    error_ax.scatter(anomaly_indices, row_mse[anomaly_indices], color='red', s=20, label='Anomalies')
    error_ax.set_xlabel('Time Step')
    error_ax.set_ylabel('MSE')
    error_ax.set_title('Reconstruction Error')
    error_ax.legend()

    fig.tight_layout()
    plt.show()

## 6. Complete Training Pipeline

In [None]:
def train_vae_lstm_pipeline(data, config=None):
    """Complete VAE-LSTM training pipeline.

    Builds sliding windows from ``data``, trains a ``VAE_LSTM`` on them,
    scores the training windows, derives an anomaly threshold from those
    scores and plots the loss curve and error distribution.

    Args:
        data: 2-D array-like of shape ``(time_steps, features)``; converted to
            a float32 NumPy array before windowing.
        config: Optional dict with the keys ``input_dim``, ``hidden_dim``,
            ``latent_dim``, ``sequence_length``, ``batch_size``,
            ``num_epochs``, ``learning_rate`` and ``contamination``. When
            omitted, defaults are used with ``input_dim = data.shape[1]``.

    Returns:
        Tuple ``(model, threshold, train_errors)`` where ``train_errors[i]``
        is the reconstruction error of the window starting at row ``i``.

    Raises:
        ValueError: If ``data`` has fewer rows than ``sequence_length``.
    """
    if config is None:
        config = {
            'input_dim': data.shape[1],
            'hidden_dim': 64,
            'latent_dim': 32,
            'sequence_length': 50,
            'batch_size': 32,
            'num_epochs': 100,
            'learning_rate': 1e-3,
            'contamination': 0.1
        }

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Normalise the input container (ndarray, DataFrame, nested lists, ...)
    values = np.asarray(data, dtype=np.float32)
    if len(values) < config['sequence_length']:
        raise ValueError(
            f"data has {len(values)} rows but sequence_length is "
            f"{config['sequence_length']}; at least one full window is required"
        )

    # Create dataset and dataloaders
    dataset = TimeSeriesDataset(values, config['sequence_length'])
    train_loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=True)
    # Scoring uses an unshuffled loader so the errors stay in window order;
    # the shuffled training loader would scramble the time axis.
    eval_loader = DataLoader(dataset, batch_size=config['batch_size'], shuffle=False)

    # Initialize model
    model = VAE_LSTM(
        config['input_dim'],
        config['hidden_dim'],
        config['latent_dim'],
        config['sequence_length']
    ).to(device)

    # Optimizer
    optimizer = optim.Adam(model.parameters(), lr=config['learning_rate'])

    # Train model
    print("Training VAE-LSTM model...")
    train_losses = train_vae_lstm(model, train_loader, optimizer, device, config['num_epochs'])

    # Compute reconstruction errors on training data
    print("Computing reconstruction errors...")
    train_errors = compute_reconstruction_error(model, eval_loader, device)

    # Set anomaly threshold
    threshold = set_anomaly_threshold(train_errors, config['contamination'])
    print(f"Anomaly threshold set to: {threshold:.4f}")

    # Plot training results
    plot_training_loss(train_losses)
    plot_reconstruction_errors(train_errors, threshold)

    return model, threshold, train_errors

## 7. Model Saving and Loading

In [None]:
def save_vae_lstm_model(model, threshold, config, filepath):
    """Save trained VAE-LSTM model and parameters.

    Writes a single checkpoint containing the model's ``state_dict``, the
    anomaly threshold and the configuration dict.

    Args:
        model: Module whose ``state_dict()`` is stored.
        threshold: Anomaly threshold; stored as a plain Python float (or
            ``None`` if none was given).
        config: Configuration dict stored unchanged under ``'config'``.
        filepath: Destination path or file-like object accepted by
            ``torch.save``. For paths, missing parent directories are created.
    """
    import os

    if isinstance(filepath, (str, os.PathLike)):
        parent_dir = os.path.dirname(os.fspath(filepath))
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Store a builtin float rather than a NumPy scalar so the checkpoint holds
    # only tensors and plain Python values.
    # NOTE(review): this should also keep the file loadable with
    # torch.load(..., weights_only=True) -- confirm on the targeted version.
    stored_threshold = None if threshold is None else float(threshold)

    torch.save({
        'model_state_dict': model.state_dict(),
        'threshold': stored_threshold,
        'config': config
    }, filepath)
    print(f"Model saved to {filepath}")

def load_vae_lstm_model(filepath, device):
    """Load trained VAE-LSTM model and parameters.

    Reads a checkpoint with the keys ``'config'``, ``'model_state_dict'`` and
    ``'threshold'``, rebuilds a ``VAE_LSTM`` from the stored configuration on
    ``device`` and restores its weights.

    Returns:
        Tuple ``(model, threshold, config)``.
    """
    # NOTE(review): torch.load unpickles the file; presumably checkpoints come
    # from a trusted source -- confirm before loading external files.
    checkpoint = torch.load(filepath, map_location=device)
    config = checkpoint['config']

    # Constructor arguments, in the order VAE_LSTM is called with them here
    architecture_keys = ('input_dim', 'hidden_dim', 'latent_dim', 'sequence_length')
    constructor_args = [config[key] for key in architecture_keys]
    model = VAE_LSTM(*constructor_args)
    model = model.to(device)

    model.load_state_dict(checkpoint['model_state_dict'])
    threshold = checkpoint['threshold']

    print(f"Model loaded from {filepath}")
    return model, threshold, config

## 8. Example Usage

In [None]:
# Example usage (uncomment and modify for your data)
# 
# # Load and preprocess your data
# # data = ... # Your preprocessed time series data (numpy array or pandas DataFrame)
# 
# # Define configuration
# config = {
#     'input_dim': data.shape[1],
#     'hidden_dim': 64,
#     'latent_dim': 32,
#     'sequence_length': 50,
#     'batch_size': 32,
#     'num_epochs': 100,
#     'learning_rate': 1e-3,
#     'contamination': 0.1
# }
# 
# # Train model
# model, threshold, train_errors = train_vae_lstm_pipeline(data, config)
# 
# # Save model
# save_vae_lstm_model(model, threshold, config, 'models/vae_lstm_model.pth')
# 
# # For inference on new data:
# # (`device` is local to train_vae_lstm_pipeline, so define it here first)
# # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# # test_dataset = TimeSeriesDataset(test_data, config['sequence_length'])
# # test_loader = DataLoader(test_dataset, batch_size=config['batch_size'])
# # test_errors = compute_reconstruction_error(model, test_loader, device)
# # anomalies = detect_anomalies(test_errors, threshold)

# Confirmation message printed when this cell runs
print("VAE-LSTM training functions defined. Ready to use!")