In [10]:
import base64
import hashlib
import hmac
import json
import os
import pickle
import threading
import time
import uuid
from collections import defaultdict
from collections import deque

import numpy as np
import redis
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.kdf.hkdf import HKDF
from sklearn.ensemble import IsolationForest


# Encryption key material. It must be identical on the aggregator and every
# node, because both sides derive the Fernet key from it the same way.
# The hard-coded fallback keeps existing deployments interoperable, but it is
# public (it is in source control). Set KNN_ENCRYPTION_PASSPHRASE to override it.
key = os.environ.get('KNN_ENCRYPTION_PASSPHRASE', 'password').encode()
hkdf = HKDF(
    algorithm=hashes.SHA256(),
    length=32,   # Fernet needs exactly 32 bytes of key material
    salt=None,   # NOTE(review): unsalted; nodes must derive with the same salt
    info=None,
    backend=default_backend(),
)
# Fernet expects its 32-byte key as URL-safe base64.
key = base64.urlsafe_b64encode(hkdf.derive(key))
cipher_suite = Fernet(key)

# HMAC-SHA256 key used to sign and verify node responses. Set KNN_HMAC_SECRET
# to override the public default (which must then match on every node).
SECRET_KEY = os.environ.get('KNN_HMAC_SECRET', 'my_secret_key').encode()

class KNNAggregator:
    """Distribute KNN queries to worker nodes over Redis and aggregate the answers.

    The constructor starts two background threads:

    * ``listen_for_queries`` reads pickled queries from ``knn_query_channel``,
      forwards each one to every node, and publishes the aggregated answer on
      ``knn_query_response_channel``.
    * ``listen_for_responses`` collects node answers from
      ``knn_response_channel`` into ``self.requests``.

    A node response is a dict with ``node_id``, ``request_id``, ``prediction``
    (a list of labels) and ``distances`` (one distance per prediction, paired
    by position).

    Args:
        node_count: Number of nodes; nodes are addressed as ids 1..node_count.
        message_integrity: Verify the HMAC signature of every node response.
        message_tampering: Drop outlier responses with an IsolationForest
            before voting (defence against label flipping).
        rate_limit_bool: Drop queries beyond ``rate_limit`` per second.
        encypt: Fernet-encrypt outgoing requests. The misspelt name is kept for
            backward compatibility; the value is stored as ``self.encrypt``.
        rate_limit: Maximum number of queries accepted per second.
    """

    def __init__(self, node_count, message_integrity=False,
                 message_tampering=False, rate_limit_bool=False,
                 encypt=False, rate_limit=10):
        self.redis_client = redis.StrictRedis(
            host='localhost', port=6379, db=0, password='secret_key')
        self.subscribe_channel = 'knn_response_channel'
        self.query_channel = 'knn_query_channel'  # channel for incoming search queries
        self.node_count = node_count
        # One voting weight per node. Nodes without an entry count as 1.0.
        self.weights = {node_id: 1.0 for node_id in range(1, node_count + 1)}
        self.requests = {}               # request_id -> list of node responses
        self.active_request_ids = set()  # requests still accepting responses
        self.message_integrity = message_integrity
        self.message_tampering = message_tampering
        self.encrypt = encypt
        self.rate_limit = rate_limit
        self.rate_limit_bool = rate_limit_bool
        self.request_times = deque()     # accept times within the last second
        # Threads start last, so every attribute they read already exists.
        self._listen_thread = threading.Thread(target=self.listen_for_responses)
        self._listen_thread.start()
        self._query_thread = threading.Thread(target=self.listen_for_queries)
        self._query_thread.start()

    def sign_message(self, message):
        """Wrap *message* with the hex HMAC-SHA256 of its ``json.dumps`` encoding."""
        message_json = json.dumps(message)
        signature = hmac.new(SECRET_KEY, message_json.encode(), hashlib.sha256).hexdigest()
        return {'message': message, 'signature': signature}

    def verify_message(self, signed_message):
        """Return True iff ``signed_message['signature']`` matches its message.

        The JSON encoding must match the signer's byte for byte (presumably the
        same default ``json.dumps`` used by ``sign_message``). A missing or
        non-string signature is rejected instead of raising.
        """
        signature = signed_message.get('signature')
        if not isinstance(signature, str):
            return False
        message_json = json.dumps(signed_message['message'])
        expected_signature = hmac.new(SECRET_KEY, message_json.encode(), hashlib.sha256).hexdigest()
        return hmac.compare_digest(expected_signature, signature)

    def check_rate_limit(self):
        """Apply a sliding one-second window to accepted queries.

        Records the call and returns True if fewer than ``self.rate_limit``
        calls were accepted during the last second. Otherwise returns False.
        """
        now = time.time()
        while self.request_times and now - self.request_times[0] > 1:
            self.request_times.popleft()
        if len(self.request_times) >= self.rate_limit:
            return False
        self.request_times.append(now)
        return True

    def listen_for_responses(self):
        """Store node responses from ``self.subscribe_channel`` forever (own thread)."""
        pubsub = self.redis_client.pubsub()
        pubsub.subscribe(self.subscribe_channel)
        for message in pubsub.listen():
            if message['type'] != 'message':
                continue
            try:
                self._store_response(message['data'])
            except Exception as exc:
                # Thread boundary: one malformed message must not stop collection.
                print(f"Discarding malformed response: {exc!r}")

    def _store_response(self, raw):
        """Decode one pickled (optionally signed) response and file it under its request."""
        # SECURITY: pickle.loads runs arbitrary code if an attacker can publish
        # on this channel. Switching to JSON needs the same change on the nodes.
        signed_response = pickle.loads(raw)
        if self.message_integrity and not self.verify_message(signed_response):
            node_id = signed_response['message']['node_id']
            print(f"The message sent by node {node_id} is not correctly signed")
            return
        response = signed_response['message']
        request_id = response['request_id']
        # send_request creates the list before activating the id. Never
        # (re)create it here, or a request finished concurrently would be
        # resurrected and leak.
        bucket = self.requests.get(request_id)
        if bucket is not None and request_id in self.active_request_ids:
            bucket.append(response)

    def listen_for_queries(self):
        """Serve queries from ``self.query_channel`` forever (own thread).

        Queries are handled one at a time. Each one blocks in
        ``handle_responses`` until every node answers or the timeout expires.
        """
        pubsub = self.redis_client.pubsub()
        pubsub.subscribe(self.query_channel)
        for message in pubsub.listen():
            if message['type'] != 'message':
                continue
            try:
                self._process_query(message['data'])
            except Exception as exc:
                # Thread boundary: a bad query must not stop the query loop.
                print(f"Failed to process query: {exc!r}")

    def _process_query(self, raw):
        """Decode one pickled query, enforce the rate limit, and answer it."""
        # SECURITY: pickle.loads on channel data is unsafe with untrusted
        # publishers. Flagged only, because the query senders depend on pickle.
        query_data = pickle.loads(raw)
        request_id = query_data['request_id']
        sample = query_data['sample']
        distance = query_data['distance']
        voting_method = query_data['voting_method']
        if self.rate_limit_bool and not self.check_rate_limit():
            print("dropped", request_id)
            return
        self.send_request(sample, distance, request_id)
        self.handle_responses(request_id, voting_method)

    def send_request(self, sample, distance, request_id):
        """Publish *sample* to every node's ``knn_{distance}_request_channel_{id}``.

        The payload is pickled and, if ``self.encrypt`` is set, Fernet-encrypted.
        The request is registered before publishing so that fast responses are
        not discarded.
        """
        request_data = {'sample': sample, 'request_id': request_id}
        payload = pickle.dumps(request_data)
        if self.encrypt:
            payload = cipher_suite.encrypt(payload)
        self.requests[request_id] = []
        self.active_request_ids.add(request_id)
        for node_id in range(1, self.node_count + 1):
            self.redis_client.publish(f'knn_{distance}_request_channel_{node_id}', payload)

    def handle_responses(self, request_id, voting_method, timeout=60):
        """Wait for all node answers, aggregate them, and publish the result.

        Polls every 0.1 s. Publishes ``{'request_id': ..., 'result': ...}`` on
        ``knn_query_response_channel``. The result is the aggregated label,
        ``None`` for an unknown *voting_method*, or ``[]`` if fewer than
        ``node_count`` responses arrived within *timeout* seconds. The
        request's bookkeeping is always released.
        """
        voters = {
            'majority_voting': self.majority_voting,
            'min_distance': self.min_distance,
            'weighted_voting': self.weighted_voting,
            'distance_based_aggregation': self.distance_based_aggregation,
        }
        deadline = time.time() + timeout
        while len(self.requests.get(request_id, [])) < self.node_count:
            if time.time() > deadline:
                self._finish_request(request_id)
                print(f"Timeout reached for request {request_id}; publishing empty result.")
                self.redis_client.publish('knn_query_response_channel',
                                          pickle.dumps({'request_id': request_id, 'result': []}))
                return
            time.sleep(0.1)  # avoid busy-waiting

        voter = voters.get(voting_method)
        try:
            result = voter(request_id) if voter is not None else None
        finally:
            # Also covers an unknown method or a voter that raised.
            self._finish_request(request_id)
        self.redis_client.publish('knn_query_response_channel',
                                  pickle.dumps({'request_id': request_id, 'result': result}))

    def _finish_request(self, request_id):
        """Forget *request_id* so later responses for it are ignored. Idempotent."""
        self.active_request_ids.discard(request_id)
        self.requests.pop(request_id, None)

    def create_feature_vectors(self, request_id):
        """Return one row per stored response: its distances followed by its predictions."""
        return self._build_feature_vectors(self.requests[request_id])

    @staticmethod
    def _build_feature_vectors(responses):
        """Stack ``distances`` + ``prediction`` of each response into an array.

        Assumes every response has the same number of neighbours. Ragged
        responses do not form a proper 2-D array.
        """
        return np.array([np.concatenate([r['distances'], r['prediction']]) for r in responses])

    def detect_anomalies_isolation_forest(self, request_id, contamination=0.1):
        """Drop outlier responses for *request_id* when ``message_tampering`` is on.

        Responses that an IsolationForest labels ``-1`` are removed from
        ``self.requests[request_id]``. This is a no-op when the flag is off or
        no responses are stored.
        """
        if not self.message_tampering:
            return
        # NOTE(review): the purpose of this pause is not visible here. It
        # presumably lets late responses arrive before filtering. Confirm.
        time.sleep(1)
        # Take one snapshot so the features and the filtered list match.
        responses = list(self.requests.get(request_id, []))
        if not responses:
            return
        feature_vectors = self._build_feature_vectors(responses)
        model = IsolationForest(contamination=contamination)
        model.fit(feature_vectors)
        labels = model.predict(feature_vectors)
        self.requests[request_id] = [
            response for response, label in zip(responses, labels) if label != -1
        ]

    def _take_responses(self, request_id):
        """Filter outliers (if enabled), then remove and return *request_id*'s responses."""
        self.detect_anomalies_isolation_forest(request_id)
        responses = self.requests.get(request_id, [])
        self._finish_request(request_id)
        return responses

    @staticmethod
    def _top_label(votes):
        """Return the highest-scoring key (first inserted wins ties), or None if empty."""
        return max(votes, key=votes.get) if votes else None

    def majority_voting(self, request_id):
        """Return the most frequent label across all nodes' predictions, or None."""
        votes = defaultdict(int)
        for response in self._take_responses(request_id):
            for prediction in response['prediction']:
                votes[prediction] += 1
        return self._top_label(votes)

    def min_distance(self, request_id):
        """Return the label of the single closest neighbour overall, or None.

        On equal distances, the first one encountered wins.
        """
        best = None  # (distance, label)
        for response in self._take_responses(request_id):
            for dist, prediction in zip(response['distances'], response['prediction']):
                if best is None or dist < best[0]:
                    best = (dist, prediction)
        return None if best is None else best[1]

    def weighted_voting(self, request_id):
        """Return the label with the largest sum of per-node weights, or None.

        Nodes missing from ``self.weights`` vote with weight 1.0.
        """
        votes = defaultdict(float)
        for response in self._take_responses(request_id):
            weight = self.weights.get(response['node_id'], 1.0)
            for prediction in response['prediction']:
                votes[prediction] += weight
        return self._top_label(votes)

    def distance_based_aggregation(self, request_id):
        """Return the label with the largest sum of inverse distances, or None."""
        votes = defaultdict(float)
        for response in self._take_responses(request_id):
            for dist, prediction in zip(response['distances'], response['prediction']):
                votes[prediction] += 1 / (dist + 1e-5)  # epsilon guards zero distance
        return self._top_label(votes)


# Create the aggregator with every defence switched off. The cells below flip
# individual flags on this live instance. The constructor creates the Redis
# client and immediately starts the response and query listener threads.
aggregator = KNNAggregator(
    node_count=3,
    message_integrity=False,
    message_tampering=False,
    rate_limit_bool=False,
    encypt=False,
)

0
{'message': {'node_id': 2, 'prediction': [7, 7, 6], 'distances': [0.3971969268473038, 0.1397830694330423, 0.32300462126035157], 'request_id': '156153bd-9568-4074-9e65-69c7ce211079'}, 'signature': 'f'}
The message sent by node 2 is not correctly signed
{'message': {'node_id': 1, 'prediction': [6, 6, 6], 'distances': [0.02518019982513464, 0.05141108189816079, 0.05162731029869594], 'request_id': '156153bd-9568-4074-9e65-69c7ce211079'}, 'signature': '0867fd3f888d2610a2c4f6818e01b3c525158ab79931e26ea5e4f8099d2e9096'}
{'message': {'node_id': 2, 'prediction': [6, 6, 6], 'distances': [0.030774054370998005, 0.04430344274200948, 0.050306759120160915], 'request_id': '156153bd-9568-4074-9e65-69c7ce211079'}, 'signature': '917e7db954f9eae34632842ed839650b586d8cf962291b26e624a1f179f623e9'}
{'message': {'node_id': 3, 'prediction': [6, 6, 6], 'distances': [0.0373170632272567, 0.039360543012593174, 0.0504187186398124], 'request_id': '156153bd-9568-4074-9e65-69c7ce211079'}, 'signature': '8300b6cf421e5e

In [4]:
# Turn on query rate limiting: queries beyond `aggregator.rate_limit` within
# one second are dropped by the query listener.
aggregator.rate_limit_bool = True

In [5]:
# Turn query rate limiting off again; every query is served.
aggregator.rate_limit_bool = False

In [11]:
# Defense against message tampering during transmission: check the HMAC
# signature of every node response and discard responses that fail it.
aggregator.message_integrity = True

In [None]:
# Allow tampering during transmission: accept node responses without
# checking their signatures.
aggregator.message_integrity = False

In [4]:
# Fernet-encrypt the pickled requests sent to the nodes (the aggregator does
# not decrypt incoming responses).
aggregator.encrypt = True

In [5]:
# Send requests to the nodes as plain pickled data again.
aggregator.encrypt = False

In [13]:
# Defense against label flipping: filter outlier node responses with an
# IsolationForest before voting.
aggregator.message_tampering = True

In [14]:
# Allow label flipping: vote on all node responses without outlier filtering.
aggregator.message_tampering = False