In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import List, Tuple, Optional

class BatchNormalizationLayer(nn.Module):
    """
    Custom Batch Normalization Layer as described in the paper.

    In training mode the input is normalized with the statistics of the
    current batch (computed over dim 0) and those statistics are folded into
    the ``running_mean`` / ``running_var`` buffers. In eval mode the input is
    normalized with the buffers instead. A learnable scale (``gamma``) and
    shift (``beta``) are applied in both modes.

    NOTE(review): the buffers are updated as
    ``running = momentum * running + (1 - momentum) * batch``, i.e.
    ``momentum`` is the weight kept on the OLD value. With the default of 0.1
    each new batch contributes 90% of the running estimate. This is the
    opposite of ``torch.nn.BatchNorm1d``'s convention -- confirm against the
    paper's definition of κ before relying on the default.

    Args:
        input_size: Number of features; length of ``gamma``, ``beta`` and of
            both running-statistic buffers.
        momentum: κ in the paper; weight of the previous running value.
        eps: ϵ in the paper; added to the variance before the square root.
    """
    def __init__(self, input_size: int, momentum: float = 0.1, eps: float = 1e-5):
        super().__init__()
        self.input_size = input_size
        self.momentum = momentum  # κ in the paper
        self.eps = eps  # ϵ in the paper

        # Learnable parameters γ and β
        self.gamma = nn.Parameter(torch.ones(input_size))
        self.beta = nn.Parameter(torch.zeros(input_size))

        # Moving averages for inference (buffers: saved in the state dict,
        # not returned by parameters()).
        self.register_buffer('running_mean', torch.zeros(input_size))
        self.register_buffer('running_var', torch.ones(input_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Normalize ``x`` and apply the learned scale and shift.

        Args:
            x: Input whose dim 0 is the batch dimension. Assumes the trailing
                dimension has size ``input_size`` so that the per-feature
                vectors broadcast -- TODO confirm callers pass
                ``(batch, input_size)``.

        Returns:
            ``gamma * x_normalized + beta``.
        """
        if self.training:
            # Training phase: use batch statistics (biased variance).
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)

            # Update the moving averages outside autograd. Without this, an
            # input that carries grad history would turn the buffers into
            # graph-attached tensors and chain together the graphs of every
            # previous iteration.
            with torch.no_grad():
                keep = self.momentum
                self.running_mean = keep * self.running_mean + (1 - keep) * batch_mean
                self.running_var = keep * self.running_var + (1 - keep) * batch_var

            # Normalize using batch statistics
            x_normalized = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
        else:
            # Inference phase: use moving averages
            x_normalized = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        # Apply scale and shift
        return self.gamma * x_normalized + self.beta


class MultiFaultDiagnosisNN(nn.Module):
    """
    Multi-fault diagnosis neural network as described in Fig. 2(a) of the paper.

    The network is a batch-normalization layer followed by a stack of
    ``Linear`` + ``ReLU`` layers and a final ``Linear`` layer with
    ``num_faults`` outputs; ``forward`` applies a sigmoid to each output, so
    it is suited to binary cross-entropy loss for multi-label classification.

    Args:
        input_size: Number of input features.
        num_faults: Number of output units (one per fault type).
        hidden_layers: Widths of the hidden layers, in order. ``None``
            (the default) selects ``[128, 64, 32]``.
    """
    def __init__(self, input_size: int, num_faults: int,
                 hidden_layers: Optional[List[int]] = None):
        super().__init__()
        # A None sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if hidden_layers is None:
            hidden_layers = [128, 64, 32]

        self.input_size = input_size
        self.num_faults = num_faults

        # Batch normalization as first layer
        self.batch_norm = BatchNormalizationLayer(input_size)

        # Build fully connected layers: Linear + ReLU per hidden width.
        widths = [input_size, *hidden_layers]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())

        # Output layer; the sigmoid is applied in forward().
        layers.append(nn.Linear(widths[-1], num_faults))

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-fault sigmoid outputs for the batch ``x``."""
        # Apply batch normalization first
        x = self.batch_norm(x)

        # Pass through FC layers
        x = self.fc_layers(x)

        # Apply sigmoid activation for multi-label classification
        return torch.sigmoid(x)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Online inference with threshold decision.

        Runs the forward pass in eval mode without gradient tracking and
        returns a float tensor holding 1.0 where the output is strictly
        greater than ``threshold`` and 0.0 elsewhere. The module's previous
        train/eval mode is restored afterwards, so calling this in the middle
        of training does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x)
                predictions = (output > threshold).float()
        finally:
            self.train(was_training)
        return predictions


class SeverityDiagnosisNN(nn.Module):
    """
    Severity diagnosis neural network for individual fault types.

    The network is a batch-normalization layer followed by a stack of
    ``Linear`` + ``ReLU`` layers and a final ``Linear`` layer with
    ``num_severity_levels`` outputs. ``forward`` returns softmax
    probabilities over dim 1 (multi-class classification), NOT logits.

    Args:
        input_size: Number of input features.
        num_severity_levels: Number of output classes.
        hidden_layers: Widths of the hidden layers, in order. ``None``
            (the default) selects ``[128, 64, 32]``.
    """
    def __init__(self, input_size: int, num_severity_levels: int = 3,
                 hidden_layers: Optional[List[int]] = None):
        super().__init__()
        # A None sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if hidden_layers is None:
            hidden_layers = [128, 64, 32]

        self.input_size = input_size
        self.num_severity_levels = num_severity_levels

        # Batch normalization as first layer
        self.batch_norm = BatchNormalizationLayer(input_size)

        # Build fully connected layers: Linear + ReLU per hidden width.
        widths = [input_size, *hidden_layers]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())

        # Output layer; the softmax is applied in forward().
        layers.append(nn.Linear(widths[-1], num_severity_levels))

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-class softmax probabilities (over dim 1) for ``x``."""
        # Apply batch normalization first
        x = self.batch_norm(x)

        # Pass through FC layers
        x = self.fc_layers(x)

        # Apply softmax for multi-class classification
        return F.softmax(x, dim=1)

    def predict(self, x: torch.Tensor) -> torch.Tensor:
        """
        Online inference with argmax decision.

        Runs the forward pass in eval mode without gradient tracking and
        returns, for each row, the index of the largest output along dim 1.
        The module's previous train/eval mode is restored afterwards, so
        calling this in the middle of training does not silently leave the
        model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x)
                predictions = torch.argmax(output, dim=1)
        finally:
            self.train(was_training)
        return predictions


class SeparatedStructure(nn.Module):
    """
    Separated structure for fault and severity diagnosis.

    Contains one fault diagnosis NN and NF severity diagnosis NNs (one per
    fault type). Every sub-network receives the same input.

    Args:
        input_size: Number of input features fed to every sub-network.
        num_faults: Number of fault types (NF); also the number of severity
            networks created.
        num_severity_levels: Number of classes of each severity network.
        hidden_layers: Hidden-layer widths used for every sub-network.
            ``None`` (the default) selects ``[128, 64, 32]``.
    """
    def __init__(self, input_size: int, num_faults: int, num_severity_levels: int = 3,
                 hidden_layers: Optional[List[int]] = None):
        super().__init__()
        # Resolve the default here (rather than passing None down) so the
        # sub-network constructors always receive a concrete list, and no
        # mutable default is shared between instances.
        if hidden_layers is None:
            hidden_layers = [128, 64, 32]

        self.num_faults = num_faults
        self.num_severity_levels = num_severity_levels

        # Fault diagnosis network
        self.fault_diagnosis = MultiFaultDiagnosisNN(input_size, num_faults, hidden_layers)

        # Severity diagnosis networks (one for each fault type)
        self.severity_networks = nn.ModuleList([
            SeverityDiagnosisNN(input_size, num_severity_levels, hidden_layers)
            for _ in range(num_faults)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Run every sub-network on ``x``.

        Returns:
            A pair ``(fault_predictions, severity_predictions)``: the output
            of the fault network, and a list with the output of each severity
            network in the order they are stored.
        """
        # Diagnose faults
        fault_predictions = self.fault_diagnosis(x)

        # Diagnose severity for each fault type
        severity_predictions = [severity_net(x) for severity_net in self.severity_networks]

        return fault_predictions, severity_predictions

    def predict(self, x: torch.Tensor, fault_threshold: float = 0.5) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Combined prediction for faults and severities.

        Calls ``predict`` on the fault network (with ``fault_threshold``) and
        on every severity network, in eval mode and without gradient
        tracking. The previous train/eval mode of this module and its
        sub-networks is restored afterwards, so calling this in the middle of
        training does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                fault_pred = self.fault_diagnosis.predict(x, fault_threshold)
                severity_preds = [
                    severity_net.predict(x) for severity_net in self.severity_networks
                ]
        finally:
            # train(mode) propagates to all children as well.
            self.train(was_training)

        return fault_pred, severity_preds


class JointStructure(nn.Module):
    """
    Joint structure for simultaneous fault and severity diagnosis.

    Uses a single NN for both fault detection and severity classification.
    The output layer has ``num_faults * num_severity_levels`` sigmoid units;
    ``decode_predictions`` reads them as ``num_faults`` consecutive groups of
    ``num_severity_levels`` units each.

    Args:
        input_size: Number of input features.
        num_faults: Number of fault types.
        num_severity_levels: Number of severity levels per fault type.
        hidden_layers: Widths of the hidden layers, in order. ``None``
            (the default) selects ``[128, 64, 32]``.
    """
    def __init__(self, input_size: int, num_faults: int, num_severity_levels: int = 3,
                 hidden_layers: Optional[List[int]] = None):
        super().__init__()
        # A None sentinel avoids sharing one mutable default list between
        # every instance of the class.
        if hidden_layers is None:
            hidden_layers = [128, 64, 32]

        self.input_size = input_size
        self.num_faults = num_faults
        self.num_severity_levels = num_severity_levels
        self.output_size = num_faults * num_severity_levels

        # Batch normalization as first layer
        self.batch_norm = BatchNormalizationLayer(input_size)

        # Build fully connected layers: Linear + ReLU per hidden width.
        widths = [input_size, *hidden_layers]
        layers = []
        for in_features, out_features in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.ReLU())

        # Output layer for joint fault-severity classification
        layers.append(nn.Linear(widths[-1], self.output_size))

        self.fc_layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the sigmoid outputs of all fault/severity units for ``x``."""
        # Apply batch normalization first
        x = self.batch_norm(x)

        # Pass through FC layers
        x = self.fc_layers(x)

        # Apply sigmoid activation for multi-label classification
        return torch.sigmoid(x)

    def predict(self, x: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """
        Online inference with threshold decision.

        Runs the forward pass in eval mode without gradient tracking and
        returns a float tensor holding 1.0 where the output is strictly
        greater than ``threshold`` and 0.0 elsewhere. The module's previous
        train/eval mode is restored afterwards, so calling this in the middle
        of training does not silently leave the model in eval mode.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                output = self.forward(x)
                predictions = (output > threshold).float()
        finally:
            self.train(was_training)
        return predictions

    def decode_predictions(self, predictions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Decode joint predictions into fault and severity predictions.

        Args:
            predictions: 2-D tensor with at least
                ``num_faults * num_severity_levels`` columns; columns beyond
                that are ignored. Fault ``i`` owns the columns
                ``[i * num_severity_levels, (i + 1) * num_severity_levels)``.

        Returns:
            ``(fault_predictions, severity_predictions)``, both float tensors
            of shape ``(batch, num_faults)`` on the same device as
            ``predictions``:

            * ``fault_predictions[b, i]`` is 1.0 if any unit of fault ``i``
              exceeds 0.5, else 0.0.
            * ``severity_predictions[b, i]`` is the ``torch.argmax`` index
              within fault ``i``'s group. When no unit fired (or several are
              tied) the index is decided by argmax's tie-breaking, so read it
              together with ``fault_predictions``.

        Raises:
            ValueError: If ``predictions`` is not 2-D or has too few columns.
        """
        expected_cols = self.num_faults * self.num_severity_levels
        if predictions.dim() != 2 or predictions.shape[1] < expected_cols:
            raise ValueError(
                f"predictions must be 2-D with at least {expected_cols} columns, "
                f"got shape {tuple(predictions.shape)}"
            )

        # View the flat output as (batch, fault, severity level) so both
        # decisions are a single reduction over the last axis; the results
        # stay on the input's device instead of being allocated on the CPU.
        grouped = predictions[:, :expected_cols].reshape(
            predictions.shape[0], self.num_faults, self.num_severity_levels
        )

        # Check if any severity level is predicted for each fault
        fault_predictions = (grouped > 0.5).any(dim=2).float()

        # Index of the largest unit within each fault's group
        severity_predictions = grouped.argmax(dim=2).float()

        return fault_predictions, severity_predictions


class FaultDiagnosisTrainer:
    """
    Training utilities for fault diagnosis networks.

    Wraps a model together with an Adam optimizer over its parameters and
    offers one training entry point per network type. The entry points only
    differ in the loss they apply; the epoch loop is shared.

    Args:
        model: The network to train.
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
    """
    def __init__(self, model: nn.Module, learning_rate: float = 0.001):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _run_epochs(self, train_loader, num_epochs: int, loss_fn) -> List[float]:
        """Shared epoch loop: optimize ``loss_fn(outputs, targets)`` and log every 10th epoch."""
        self.model.train()
        history: List[float] = []

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0
            for data, targets in train_loader:
                self.optimizer.zero_grad()

                outputs = self.model(data)
                loss = loss_fn(outputs, targets)

                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            # Count the batches actually seen instead of calling len(), so an
            # empty loader yields NaN rather than a ZeroDivisionError.
            avg_loss = total_loss / num_batches if num_batches else float('nan')
            history.append(avg_loss)

            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}')

        return history

    def train_multi_fault(self, train_loader, num_epochs: int = 100):
        """
        Train multi-fault diagnosis network.

        Uses binary cross-entropy on the model outputs, which are therefore
        expected to be probabilities in [0, 1] (e.g. sigmoid outputs).
        Targets are cast to float.

        Args:
            train_loader: Iterable yielding ``(data, targets)`` batches.
            num_epochs: Number of passes over ``train_loader``.

        Returns:
            List with the average batch loss of each epoch.
        """
        criterion = nn.BCELoss()
        return self._run_epochs(
            train_loader, num_epochs,
            lambda outputs, targets: criterion(outputs, targets.float()),
        )

    def train_severity(self, train_loader, num_epochs: int = 100):
        """
        Train severity diagnosis network.

        Expects the model to output class PROBABILITIES (as
        ``SeverityDiagnosisNN.forward`` does via softmax), so the categorical
        cross-entropy is computed as the negative log-likelihood of the
        log-probabilities. ``nn.CrossEntropyLoss`` must not be used here: it
        expects raw logits and would apply a second softmax on top of the
        model's own, flattening the loss and its gradients. Targets are cast
        to long class indices.

        Args:
            train_loader: Iterable yielding ``(data, targets)`` batches.
            num_epochs: Number of passes over ``train_loader``.

        Returns:
            List with the average batch loss of each epoch.
        """
        criterion = nn.NLLLoss()
        # Lower bound on the probabilities so log() never sees an exact zero.
        min_prob = 1e-12
        return self._run_epochs(
            train_loader, num_epochs,
            lambda outputs, targets: criterion(
                torch.log(outputs.clamp_min(min_prob)), targets.long()
            ),
        )

    def train_joint(self, train_loader, num_epochs: int = 100):
        """
        Train joint fault-severity diagnosis network.

        Uses binary cross-entropy on the model outputs, which are therefore
        expected to be probabilities in [0, 1] (e.g. sigmoid outputs).
        Targets are cast to float.

        Args:
            train_loader: Iterable yielding ``(data, targets)`` batches.
            num_epochs: Number of passes over ``train_loader``.

        Returns:
            List with the average batch loss of each epoch.
        """
        criterion = nn.BCELoss()
        return self._run_epochs(
            train_loader, num_epochs,
            lambda outputs, targets: criterion(outputs, targets.float()),
        )


# Example usage and demonstration: builds each network with random weights,
# prints its structure / parameter count, and runs a single random sample
# through the inference path. No training happens here.
if __name__ == "__main__":
    # Example parameters
    input_size = 50  # Number of KPIs
    num_faults = 3   # ERP, ED, EU
    num_severity_levels = 3
    batch_size = 32

    # Generate sample data (random features and random binary labels; not
    # consumed below, kept available for follow-up experiments).
    X = torch.randn(1000, input_size)
    y_faults = torch.randint(0, 2, (1000, num_faults))  # Multi-label fault data
    y_joint = torch.randint(0, 2, (1000, num_faults * num_severity_levels))  # Joint labels

    print("=== Multi-Fault Diagnosis Network ===")
    fault_model = MultiFaultDiagnosisNN(input_size, num_faults)
    print(fault_model)

    # Test forward pass. Switch to eval mode first: in training mode the
    # batch-norm layer would normalize this single sample with its own batch
    # statistics (turning it into all zeros) and overwrite the running
    # statistics with values taken from a batch of one.
    fault_model.eval()
    with torch.no_grad():
        sample_input = torch.randn(1, input_size)
        fault_output = fault_model(sample_input)
        fault_prediction = fault_model.predict(sample_input)
        print(f"Fault probabilities: {fault_output}")
        print(f"Fault predictions: {fault_prediction}")

    print("\n=== Severity Diagnosis Network ===")
    severity_model = SeverityDiagnosisNN(input_size, num_severity_levels)
    print(severity_model)

    print("\n=== Separated Structure ===")
    separated_model = SeparatedStructure(input_size, num_faults, num_severity_levels)
    print(f"Number of parameters: {sum(p.numel() for p in separated_model.parameters())}")

    print("\n=== Joint Structure ===")
    joint_model = JointStructure(input_size, num_faults, num_severity_levels)
    print(f"Number of parameters: {sum(p.numel() for p in joint_model.parameters())}")

    # Test joint structure (eval mode for the same reason as above).
    joint_model.eval()
    with torch.no_grad():
        joint_output = joint_model(sample_input)
        joint_prediction = joint_model.predict(sample_input)
        fault_pred, severity_pred = joint_model.decode_predictions(joint_prediction)
        print(f"Joint output shape: {joint_output.shape}")
        print(f"Decoded fault predictions: {fault_pred}")
        print(f"Decoded severity predictions: {severity_pred}")

=== Multi-Fault Diagnosis Network ===
MultiFaultDiagnosisNN(
  (batch_norm): BatchNormalizationLayer()
  (fc_layers): Sequential(
    (0): Linear(in_features=50, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=3, bias=True)
  )
)
Fault probabilities: tensor([[0.4689, 0.4588, 0.4534]])
Fault predictions: tensor([[0., 0., 0.]])

=== Severity Diagnosis Network ===
SeverityDiagnosisNN(
  (batch_norm): BatchNormalizationLayer()
  (fc_layers): Sequential(
    (0): Linear(in_features=50, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=64, bias=True)
    (3): ReLU()
    (4): Linear(in_features=64, out_features=32, bias=True)
    (5): ReLU()
    (6): Linear(in_features=32, out_features=3, bias=True)
  )
)

=== Separated Structure ===
Number of parameters: 68252