# Probabilistic Neural Network
## Based on Fabre & Challet

In [1]:
import machine_learning as ml
import preprocessing
import numpy as np
from matplotlib import pyplot as plt
from sklearn.svm import OneClassSVM
from sklearn.preprocessing import StandardScaler
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import roc_auc_score, average_precision_score, fbeta_score, confusion_matrix, ConfusionMatrixDisplay

# Train PNN

In [2]:
# Training hyperparameters for the PNN run.
epochs = 1_000          # upper bound on the number of training epochs
learning_rate = 0.001   # optimizer step size
patience = 100          # early-stopping patience, in epochs

In [3]:
def train_pnn(model, train_loader, val_loader, epochs=1000, lr=1e-3, patience=100, device='cpu'):
    """
    Train the Probabilistic Neural Network with early stopping.

    Each epoch runs one optimisation pass over ``train_loader`` followed by
    an evaluation pass over ``val_loader``. The weights that achieve the
    lowest mean validation loss are kept in memory (and also written to
    ``best_pnn_model.pth``) and are restored into ``model`` before returning.

    Args:
        model: PNN model; calling it on a batch must return ``(mu, sigma, alpha)``
        train_loader: Training data loader yielding ``(batch_x, batch_y)``
        val_loader: Validation data loader yielding ``(batch_x, batch_y)``
        epochs: Maximum number of epochs
        lr: Learning rate for the Adam optimizer
        patience: Early stopping patience (epochs without validation improvement)
        device: 'cpu' or 'cuda'

    Returns:
        The model carrying the best validation weights seen during this call.
        If validation loss never improved (for example it was NaN throughout,
        or ``epochs`` is 0) the model is returned with its current weights.

    Raises:
        ValueError: If ``train_loader`` or ``val_loader`` yields no batches.
    """
    # Local import: used only to snapshot the best weights.
    import copy

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = ml.SkewedGaussianNLL()

    best_val_loss = float('inf')
    # In-memory copy of the best weights, so restoring them never depends on
    # a checkpoint file that may be missing or left over from a previous run.
    best_state = None
    patience_counter = 0

    for epoch in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        train_batches = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            optimizer.zero_grad()
            mu, sigma, alpha = model(batch_x)
            loss = criterion(batch_y, mu, sigma, alpha)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_batches += 1

        # Count batches instead of calling len(): this also works for loaders
        # without __len__ and turns an empty loader into a clear error.
        if train_batches == 0:
            raise ValueError('train_loader yielded no batches')
        train_loss /= train_batches

        # Validation
        model.eval()
        val_loss = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch_x, batch_y in val_loader:
                batch_x, batch_y = batch_x.to(device), batch_y.to(device)
                mu, sigma, alpha = model(batch_x)
                loss = criterion(batch_y, mu, sigma, alpha)
                val_loss += loss.item()
                val_batches += 1

        if val_batches == 0:
            raise ValueError('val_loader yielded no batches')
        val_loss /= val_batches

        # Early stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # state_dict() references the live parameters, so deep-copy it;
            # otherwise later optimizer steps would overwrite the snapshot.
            best_state = copy.deepcopy(model.state_dict())
            # Save best model (kept for anyone who reads the checkpoint file)
            torch.save(best_state, 'best_pnn_model.pth')
        else:
            # A NaN validation loss also lands here: NaN < x is always False.
            patience_counter += 1

        if patience_counter >= patience:
            print(f'Early stopping at epoch {epoch}')
            break

        if epoch % 10 == 0:
            print(f'Epoch {epoch}: Train Loss = {train_loss:.6f}, Val Loss = {val_loss:.6f}')

    # Restore best model
    if best_state is not None:
        model.load_state_dict(best_state)
    else:
        print('Validation loss never improved; returning model with its current weights')
    return model