# <b>ASYNCHRONOUS PUBLISH-SUBSCRIBE ANOMALY DETECTION WITH CHRONOS-2</b>

First, we import the libraries and modules the project needs: paho-mqtt to receive the sensor stream, NumPy and PyTorch for numerical work, and Chronos-2 for forecasting-based anomaly detection.

In [None]:
from collections import deque
from datetime import datetime
from json import loads as json_loads

import numpy as np
import paho.mqtt.client as mqtt
import torch
from chronos import Chronos2Pipeline

In [None]:
class myQueue():
    """Fixed-size FIFO buffer of aligned (data, anomaly score, timestamp) histories.

    Once ``max_size`` entries are stored, every ``push`` evicts the oldest entry
    from all three histories at once, so index ``i`` always refers to the same
    observation in each of them.
    """

    def __init__(self, max_size:int):
        """Create an empty queue.

        Args:
            max_size (int): Maximum number of entries kept (must be >= 1).

        Raises:
            ValueError: If ``max_size`` is smaller than 1.
        """
        if max_size < 1:
            raise ValueError(f"max_size must be a positive integer, got {max_size}")
        # deque(maxlen=...) discards the oldest element in O(1) on append,
        # instead of an O(n) list.pop(0).
        self.data = deque(maxlen=max_size)
        self.anomaly = deque(maxlen=max_size)
        self.timestamp = deque(maxlen=max_size)
        self.max_size = max_size

    def push(self, data, anomaly, timestamp):
        """Append one observation, evicting the oldest one if the queue is full."""
        self.data.append(data)
        self.anomaly.append(anomaly)
        self.timestamp.append(timestamp)

    def getData(self):
        """Return the stored data points, oldest first, as a new list."""
        return list(self.data)

    def getAnomaly(self):
        """Return the stored anomaly scores, oldest first, as a new list."""
        return list(self.anomaly)

    def getTimestamp(self):
        """Return the stored timestamps, oldest first, as a new list."""
        return list(self.timestamp)

    def getAll(self):
        """Return ``(timestamp, data, anomaly)`` tuples, oldest first."""
        return list(zip(self.timestamp, self.data, self.anomaly))

    def __len__(self):
        return len(self.data)

    def isFull(self):
        """Return True once ``max_size`` entries have been stored."""
        return len(self.data) >= self.max_size

# Chronos-2 Model
Next, we initialize the Chronos-2 model for anomaly detection. We load the pre-trained `amazon/chronos-2` checkpoint on the GPU if one is available, and otherwise on the CPU. The model then runs real-time inference on the data arriving through the MQTT client.

In [None]:
# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
pipeline = Chronos2Pipeline.from_pretrained(
    "amazon/chronos-2",
    device_map=device,
)
print("Model loaded successfully on device:", device)

# Prediction functions
We define the functions that turn incoming data into an anomaly decision:
- one forecasts the next point with Chronos-2;
- one checks whether the observed point falls outside the forecast distribution;
- one combines the two once enough context has been collected.


In [None]:
@torch.no_grad()
def predictNextPoint(chronos:Chronos2Pipeline, context:myQueue)->np.ndarray:
    """Forecast the quantiles of the next point of a multivariate time series.

    Args:
        chronos (Chronos2Pipeline): The loaded Chronos-2 forecasting pipeline.
        context (myQueue): Queue whose ``getData()`` returns the history, oldest
            first, as equally sized 1-D arrays (one value per variate).

    Returns:
        np.ndarray: Forecast quantiles for the next step, shape
        (n_variates, n_quantiles).
    """
    # Stack into one (T, n_variates) float array first: torch.tensor on a list of
    # ndarrays is slow, warns, and would yield an integer tensor if every
    # reading happened to be an int.
    history = torch.as_tensor(np.stack(context.getData()), dtype=torch.float32)
    # Chronos-2 takes (n_series, n_variates, history_length); here one series.
    batch = history.T.unsqueeze(0)
    forecasts = chronos.predict(batch, prediction_length=1, context_length=len(context), cross_learning=False)
    # First (only) series: drop the length-1 horizon axis, move to host memory.
    return forecasts[0].squeeze(-1).cpu().numpy()

In [None]:
def isAnomaly(prediction:np.ndarray, actual:np.ndarray, threshold:float=5.99)->int:
    """Decide whether an observation is anomalous given a quantile forecast.

    Each variate's forecast is approximated by an independent Gaussian.
    - Its mean is the median forecast.
    - Its standard deviation is the interquartile range divided by 1.35 (for a
      normal distribution IQR ≈ 1.349·σ).
    The squared standardized residuals are summed over variates, giving a
    diagonal Mahalanobis distance. Under that model the sum follows a
    chi-squared distribution with one degree of freedom per variate, so it is
    compared directly with a chi-squared critical value.

    Args:
        prediction (np.ndarray): Quantile forecast of shape
            (n_variates, n_quantiles). Columns 5, 10 and 15 are read as the
            0.25, 0.5 and 0.75 quantiles. This assumes Chronos-2's default
            21-level grid (0.01, 0.05, 0.1, ..., 0.95, 0.99); confirm against
            the pipeline's quantile levels if the model changes.
        actual (np.ndarray): Observed values, shape (n_variates,).
        threshold (float): Chi-squared critical value for the squared distance
            (default 5.99 = 95% quantile with 2 degrees of freedom).

    Returns:
        int: 1 if the observation is an anomaly, 0 otherwise.
    """
    prediction = np.asarray(prediction, dtype=float)
    actual = np.asarray(actual, dtype=float)

    q25, median, q75 = prediction[:, 5], prediction[:, 10], prediction[:, 15]
    # Floor the spread so a flat forecast (zero IQR) or crossed quantiles
    # (negative IQR) cannot produce a division by zero or a sign flip.
    sigma = np.maximum((q75 - q25) / 1.35, 1e-8)

    # Squared diagonal Mahalanobis distance. It assumes independent variates;
    # modelling their correlation would require a full covariance matrix.
    squared_distance = float(np.sum(((actual - median) / sigma) ** 2))

    # The chi-squared critical value bounds the *squared* distance, so no square
    # root is taken before the comparison.
    return int(squared_distance > threshold)

In [None]:
def anomalyInference(data:np.ndarray, context:myQueue, chronos:Chronos2Pipeline, threshold:float=5.99)->int:
    """Score a new observation against a Chronos-2 forecast built from its history.

    Args:
        data (np.ndarray): The newly received observation.
        context (myQueue): History used as forecasting context; it must be full
            before a forecast is attempted.
        chronos (Chronos2Pipeline): Pipeline used to forecast the next point.
        threshold (float): Threshold forwarded to ``isAnomaly`` (default 5.99).

    Returns:
        int: 1 for an anomaly, 0 for a normal point, -1 while the context
        window is still filling up.
    """
    if context.isFull():
        forecast = predictNextPoint(chronos, context)
        return isAnomaly(forecast, data, threshold)
    # Still warming up: not enough history to forecast from.
    return -1

# Message handling
We implement a message handler that is called whenever a new message arrives from the MQTT broker. It validates the structure of each message and then runs anomaly detection on its readings.

In [None]:
def checkCorrectness(msg:object)->tuple[bool, list[str]]:
    """Check that a decoded message has the expected structure.

    Expected shape::

        {"mac_address": str,
         "timestamp": int,
         "data": {"temperature": number, "humidity": number}}

    Args:
        msg (object): The decoded JSON payload. Any JSON value is accepted;
            a non-dict payload is reported as an error instead of raising.

    Returns:
        tuple: ``(is_correct, errors)``. ``is_correct`` is True when the list of
        error messages is empty.
    """
    import math

    def _is_finite_number(value) -> bool:
        # bool is a subclass of int, so reject it explicitly. json.loads also
        # accepts NaN/Infinity literals, which must not reach the model.
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            return False
        try:
            return math.isfinite(value)
        except OverflowError:  # int too large to convert to float
            return False

    if not isinstance(msg, dict):
        return False, ["Message is not a dictionary"]

    errorMsg:list[str] = []

    # Check keys
    if set(msg.keys()) != {'mac_address', 'timestamp', 'data'}:
        errorMsg.append("Message keys are not correct")

    # Check types of principal keys
    if 'mac_address' in msg and not isinstance(msg['mac_address'], str):
        errorMsg.append("mac_address is not a string")
    if 'timestamp' in msg and (isinstance(msg['timestamp'], bool) or not isinstance(msg['timestamp'], int)):
        errorMsg.append("timestamp is not an integer")

    # Check data entries (key equality already guarantees both keys exist)
    data = msg.get('data')
    if isinstance(data, dict) and set(data.keys()) == {'temperature', 'humidity'}:
        if not _is_finite_number(data['temperature']):
            errorMsg.append("Data entry temperature is not a number")
        if not _is_finite_number(data['humidity']):
            errorMsg.append("Data entry humidity is not a number")
    else:
        errorMsg.append("Data is not a dictionary or does not have the correct keys")

    return not errorMsg, errorMsg

In [None]:
def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect callback (version-1 callback signature).

    Subscribes to the data topic only when the broker accepted the connection
    (``rc == 0``). Subscribing here rather than once after ``connect`` means the
    subscription is renewed whenever the client (re)connects.
    """
    if rc != 0:
        print(f"Connection failed with code {rc}")
        return

    print("Connected successfully to MQTT broker")
    client.subscribe('s344860')

In [None]:
def on_message(client, userdata, msg):    
    # Decode the message
    message = json_loads(msg.payload.decode())
    queue = userdata['context']
    
    is_correct, errors = checkCorrectness(message)

    if not is_correct:
        print("Received incorrect message. Errors:", errors)
        return
        
    values = np.array([message['data']['temperature'], message['data']['humidity']])
    
    score = anomalyInference(values, queue, userdata['pipeline'], userdata['threshold'])

    queue.push(values, score, message['timestamp'])

    timestamp = datetime.fromtimestamp(message['timestamp']/1_000).strftime('%Y-%m-%d %H:%M:%S')

    match score:
        case 1: print(f"\033[91m ANOMALY DETECTED: (Timestamp: {timestamp}, Temperature: {values[0]:.1f}°C, Humidity: {values[1]:.1f}%)\033[0m")
        case 0: print(f"\033[92m NOT AN ANOMALY: (Timestamp: {timestamp}, Temperature: {values[0]:.1f}°C, Humidity: {values[1]:.1f}%)\033[0m")
        case -1: print(f"\033[94m Not enough data, need {queue.max_size - len(queue)} more for context \033[0m")
        case _: raise ValueError(f"Unexpected score value: {score}")

# Client
We use the Paho MQTT client to connect to the MQTT broker and subscribe to the topic that carries the sensor data. Each incoming message is passed to the handler defined above for validation and anomaly detection.

In [None]:
# paho-mqtt 2.x requires the callback API version as the first argument. The
# callbacks above use the version-1 signatures (e.g. on_connect(client,
# userdata, flags, rc)). paho-mqtt 1.x has no CallbackAPIVersion and takes no
# such argument.
if hasattr(mqtt, "CallbackAPIVersion"):
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
else:
    client = mqtt.Client()

# Number of past points kept in the queue and used as forecasting context.
CONTEXT_LENGTH = 32
# Threshold: chi-squared critical values with 2 degrees of freedom
# (temperature + humidity).
# More conservative: 9.21 (99%)
# Balanced:          5.99 (95%)
# More aggressive:   4.61 (90%)
THRESHOLD = 5.99

client.on_connect = on_connect
client.on_message = on_message

In [None]:
# Attach the shared state (model, context queue, threshold) before connecting
# so it is already in place when the callbacks start firing.
client.user_data_set({'pipeline': pipeline, 'context': myQueue(CONTEXT_LENGTH), 'threshold': THRESHOLD})

client.connect('broker.emqx.io', 1883)

try:
    # Blocks and dispatches on_connect/on_message until disconnected or interrupted.
    client.loop_forever()
finally:
    # Disconnect from the broker when the loop stops, including on a kernel
    # interrupt, instead of leaving the connection open.
    client.disconnect()