# In[ ]:
import requests
import sys
import random
import torch
import os
import copy
import pandas as pd
from collections import OrderedDict

import binascii

import Crypto.Random
from Crypto.Hash import SHA
from Crypto.PublicKey import RSA
from Crypto.Signature import PKCS1_v1_5

import hashlib
import json
from time import time, perf_counter, process_time
from urllib.parse import urlparse
from uuid import uuid4

from tqdm import tqdm
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim
from Models import Mnist_2NN, Mnist_CNN
from clients import ClientsGroup, client

#PQC Package
from pqcrypto.sign.falcon_512 import generate_keypair, sign, verify
from pqcrypto.sign.falcon_1024 import generate_keypair, sign, verify
from pqcrypto.sign.rainbowVc_classic import generate_keypair, sign, verify
from pqcrypto.sign.rainbowVc_cyclic import generate_keypair, sign, verify
from pqcrypto.sign.dilithium4 import generate_keypair, sign, verify
from pqcrypto.sign.dilithium2 import generate_keypair, sign, verify

# In[ ]:
# Simulation size
NUM_MINER = 2      # miners that each accumulate a share of the client updates
NUM_WORKER = 100   # total number of clients handed to ClientsGroup
ROUND = 10         # communication rounds run by the training loop

# Blockchain
MINING_SENDER = "THE BLOCKCHAIN"   # NOTE(review): not referenced in the visible code
MINING_REWARD = 1                  # NOTE(review): not referenced in the visible code
MINING_DIFFICULTY = 5              # leading '0' hex digits a proof-of-work hash must have

# FedAVG -- device and model selection
gpu = '0'                  # exported as CUDA_VISIBLE_DEVICES
model_name = 'mnist_cnn'   # 'mnist_2nn' or 'mnist_cnn'
IID = 1                    # passed to ClientsGroup; presumably toggles an IID split -- confirm

# FedAVG -- client sampling and aggregation schedule
cfraction = 0.1   # fraction of NUM_WORKER sampled each round
internal = 15     # aggregate when (clients per round * round index) is a multiple of this

# FedAVG -- local training (forwarded to each client's localUpdate)
data_ts = 100
epoch = 5
batchsize = 10
learning_rate = 0.01   # SGD step size of the shared optimizer

# Bookkeeping
val_freq = 1                 # NOTE(review): not referenced in the visible code
save_path = './checkpoints'  # NOTE(review): not referenced in the visible code

# In[ ]:
class Blockchain:
    """In-memory proof-of-work ledger.

    Each block is a dict with ``block_number``, ``timestamp``, ``transactions``,
    ``nonce`` and ``previous_hash``. Pending transactions accumulate in
    ``self.transactions`` until ``create_block`` seals them into a new block.
    """

    def __init__(self):
        # Transactions waiting to be sealed into the next block.
        self.transactions = []
        self.chain = []
        # Peer addresses added by register_node(); must exist before it is called.
        self.nodes = set()
        # Random hex string used as this node's identifier.
        self.node_id = str(uuid4()).replace('-', '')
        # Genesis block: nonce 0 and placeholder previous hash '00'.
        self.create_block(0, '00')
        # Number of times resolve_conflicts() replaced this chain.
        self.num_forking = 0
        # Global variable name this instance is bound to; resolve_conflicts()
        # uses it to skip itself when scanning the miners list.
        self.name = None

    def register_node(self, node_url):
        """Add a peer address to ``self.nodes``.

        Accepts a full URL (``http://host:port``; the netloc is stored) or a
        bare ``host:port`` string (the path is stored).

        Raises:
            ValueError: if neither a netloc nor a path can be parsed.
        """
        parsed_url = urlparse(node_url)
        if parsed_url.netloc:
            self.nodes.add(parsed_url.netloc)
        elif parsed_url.path:
            # Accepts a URL without scheme like '192.168.0.5:5000'.
            self.nodes.add(parsed_url.path)
        else:
            raise ValueError('Invalid URL')

    def verify_transaction_signature(self, sender_address, signature, transaction):
        """Check an RSA PKCS#1 v1.5 signature over ``str(transaction)``.

        Args:
            sender_address: hex-encoded RSA public key.
            signature: hex-encoded signature.
            transaction: object whose ``str()`` (UTF-8) was signed.

        Returns:
            The result of the PKCS1_v1_5 verifier's ``verify`` call.
        """
        public_key = RSA.importKey(binascii.unhexlify(sender_address))
        verifier = PKCS1_v1_5.new(public_key)
        h = SHA.new(str(transaction).encode('utf8'))
        return verifier.verify(h, binascii.unhexlify(signature))

    def submit_transaction(self, value):
        """Queue ``value`` as a pending transaction.

        No signature check is performed here; callers that need one should
        use verify_transaction_signature() first.

        Returns:
            int: the ``block_number`` of the block the transaction will be
            sealed into by the next create_block() call.
        """
        self.transactions.append(value)
        return len(self.chain) + 1

    def create_block(self, nonce, previous_hash):
        """Seal all pending transactions into a new block and append it.

        Returns:
            dict: the newly appended block.
        """
        block = {'block_number': len(self.chain) + 1,
                 'timestamp': time(),
                 'transactions': self.transactions,
                 'nonce': nonce,
                 'previous_hash': previous_hash}

        # Start a fresh pending list; the old list now belongs to the block.
        self.transactions = []

        self.chain.append(block)
        return block

    def hash(self, block):
        """Return the SHA-256 hex digest of ``block``'s JSON serialization."""
        # sort_keys=True makes the serialization independent of key insertion
        # order, so equal blocks always hash identically.
        block_string = json.dumps(block, sort_keys=True).encode()
        return hashlib.sha256(block_string).hexdigest()

    def proof_of_work(self):
        """Find a nonce that satisfies valid_proof() for the pending transactions.

        The candidate hashed for each nonce is
        ``str(self.transactions) + str(last_hash) + str(nonce)``, exactly as in
        valid_proof(). The constant prefix is hashed once and the hash state is
        copied per nonce, so large pending transactions are not re-serialized
        on every attempt.

        Returns:
            int: the smallest non-negative nonce meeting MINING_DIFFICULTY.
        """
        last_hash = self.hash(self.chain[-1])
        prefix_state = hashlib.sha256((str(self.transactions) + str(last_hash)).encode())
        target = '0' * MINING_DIFFICULTY

        nonce = 0
        while True:
            candidate = prefix_state.copy()
            candidate.update(str(nonce).encode())
            if candidate.hexdigest()[:MINING_DIFFICULTY] == target:
                return nonce
            nonce += 1

    def valid_proof(self, transactions, last_hash, nonce, difficulty=MINING_DIFFICULTY):
        """Return True if the SHA-256 of transactions+last_hash+nonce has
        ``difficulty`` leading '0' hex digits."""
        guess = (str(transactions) + str(last_hash) + str(nonce)).encode()
        guess_hash = hashlib.sha256(guess).hexdigest()
        return guess_hash[:difficulty] == '0' * difficulty

    def valid_chain(self, chain):
        """Check that ``chain`` is internally consistent.

        Every block after the first must (1) store the hash of its predecessor
        in ``previous_hash`` and (2) carry a nonce that passes valid_proof()
        for its own transactions. The first (genesis) block is not checked.
        An empty chain has no genesis block and is reported invalid.
        """
        if not chain:
            return False
        for prev_block, block in zip(chain, chain[1:]):
            if block['previous_hash'] != self.hash(prev_block):
                return False
            if not self.valid_proof(block['transactions'], block['previous_hash'],
                                    block['nonce'], MINING_DIFFICULTY):
                return False
        return True

    def mine(self):
        """Mine the pending transactions into a block while global ``sums < 1``.

        Reads and increments the module-level global ``sums``, which is not
        defined in the visible code; it must be set before calling this.
        """
        global sums
        if sums < 1:
            last_block = self.chain[-1]
            nonce = self.proof_of_work()
            self.create_block(nonce, self.hash(last_block))
            sums += 1

    def resolve_conflicts(self, miners):
        """Adopt the longest valid chain among the other miners.

        Args:
            miners: names of module-level globals holding Blockchain instances;
                the entry equal to ``self.name`` is skipped.

        Returns:
            bool: True if this chain was replaced (``num_forking`` is then
            incremented), otherwise False.
        """
        new_chain = None
        # Only chains strictly longer than ours are candidates.
        max_length = len(self.chain)

        for miner_name in miners:
            if miner_name == self.name:
                continue
            chain = globals()[miner_name].chain
            if len(chain) > max_length and self.valid_chain(chain):
                max_length = len(chain)
                new_chain = chain

        if new_chain:
            # Copy the list so blocks we append later are not also appended
            # to the peer's chain (the original shared the same list object).
            self.chain = list(new_chain)
            self.num_forking += 1
            return True

        return False

def test_mkdir(path):
    if not os.path.isdir(path):
        os.mkdir(path)
        
def to_json_dict(dictionary):
    """Return a copy of ``dictionary`` whose leaf values are converted with ``str``.

    Nested dicts are converted recursively, so their structure survives
    ``json.dumps``. Every other value (e.g. a tensor) becomes ``str(value)``.
    Keys are kept unchanged.

    Use cases:
        - saving a dictionary of tensors (the tensors become strings)
        - saving script arguments (e.g. argparse) in a readable form

    e.g.
        >>> to_json_dict({'lr': 0.01, 'opt': {'momentum': 0.9}})
        {'lr': '0.01', 'opt': {'momentum': '0.9'}}
    """
    return {
        key: to_json_dict(value) if isinstance(value, dict) else str(value)
        for key, value in dictionary.items()
    }

# In[ ]:
# Single ledger used by the training loop below: each aggregation round
# submits the averaged global weights and appends one block.
blockchain = Blockchain()

# In[ ]:
os.environ['CUDA_VISIBLE_DEVICES'] = gpu
dev = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the global model selected by `model_name`; fail early on a typo
# instead of crashing later with `None.to(...)`.
if model_name == 'mnist_2nn':
    net = Mnist_2NN()
elif model_name == 'mnist_cnn':
    net = Mnist_CNN()
else:
    raise ValueError(
        "Unknown model_name {!r}; expected 'mnist_2nn' or 'mnist_cnn'".format(model_name))

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    net = torch.nn.DataParallel(net)
net = net.to(dev)

loss_func = F.cross_entropy
opti = optim.SGD(net.parameters(), lr=learning_rate)

myClients = ClientsGroup('mnist', IID, NUM_WORKER, dev)
testDataLoader = myClients.test_data_loader

# Clients sampled per communication round (at least one).
num_in_comm = int(max(NUM_WORKER * cfraction, 1))

# Copies of the initial weights; replaced by the FedAvg result at every
# aggregation round.
global_parameters = {key: var.clone() for key, var in net.state_dict().items()}

# Performance records, one entry per aggregation round.
acc = []
cpu_start_time = process_time()
cpu_total_time = []
cpu_cost_time = []
start_time = time()
total_time = []
cost_time = []

# Signature scheme under benchmark. It is bound explicitly here instead of
# relying on the bare `generate_keypair`/`sign`/`verify` names, which each
# successive `from pqcrypto.sign.<scheme> import ...` line at the top of the
# file overwrites. To benchmark another scheme, swap the module, e.g.
# `from pqcrypto.sign import falcon_512 as pqc_scheme`.
from pqcrypto.sign import dilithium2 as pqc_scheme

dilithium2_public_key, dilithium2_secret_key = pqc_scheme.generate_keypair()

parameters = {}
# Start communicating
for i in range(1, ROUND + 1):
    # Record start time
    cpu_begin_time = process_time()
    begin_time = time()
    print("This communicate round ", i)
    avgverifycost = 0
    avgsigncost = 0
    avgsigsize = 0
    counts = 0

    order = np.random.permutation(NUM_WORKER)
    clients_in_comm = ['client{}'.format(idx) for idx in order[0:num_in_comm]]

    # Every miner starts the round with no received weights.
    for miner in range(NUM_MINER):
        parameters['miner' + str(miner)] = None

    # Each client updates its local model, signs it, and hands it to a miner.
    for client in tqdm(clients_in_comm):
        local_parameters = myClients.clients_set[client].localUpdate(
            data_ts, epoch, batchsize, net, loss_func, opti, global_parameters)

        # Serialize once, outside the timed regions, so the sign/verify
        # timings measure only the signature scheme.
        # NOTE(review): str() of a tensor follows torch's print options, which
        # summarize large tensors with "...", so the signed message may not
        # cover every weight.
        message = bytes(str(local_parameters), encoding="utf8")

        t = perf_counter()
        signature = pqc_scheme.sign(dilithium2_secret_key, message)
        avgsigncost += perf_counter() - t
        # len() is the signature's byte length; sys.getsizeof() would add the
        # Python object header.
        avgsigsize += len(signature)
        t = perf_counter()
        pqc_scheme.verify(dilithium2_public_key, message, signature)
        avgverifycost += perf_counter() - t
        counts += 1

        # A randomly chosen miner keeps a running sum of the updates it receives.
        miner_index = random.randint(0, NUM_MINER - 1)
        miner_key = 'miner' + str(miner_index)
        if parameters[miner_key] is None:
            parameters[miner_key] = {key: var.clone() for key, var in local_parameters.items()}
        else:
            for key in parameters[miner_key]:
                parameters[miner_key][key] = parameters[miner_key][key] + local_parameters[key]

    print(f'avgverifycost:{avgverifycost/counts:.8f}s')
    print(f'avgsigncost:{avgsigncost/counts:.8f}s')
    print('avgsigsize', avgsigsize/counts)

    if (num_in_comm * i) % internal == 0:
        # Miners exchange their partial sums into one sum over this round's clients.
        sum_parameters = {}
        for miner_sum in parameters.values():
            if not miner_sum:
                continue
            if not sum_parameters:
                sum_parameters = {key: var.clone() for key, var in miner_sum.items()}
            else:
                for key, var in miner_sum.items():
                    sum_parameters[key] = sum_parameters[key] + var

        # FedAvg: `parameters` is reset at the start of every round, so the sum
        # covers exactly the num_in_comm clients of this round.
        for var in global_parameters:
            global_parameters[var] = sum_parameters[var] / num_in_comm

        # Submit the aggregated model BEFORE mining, so the nonce is computed
        # over the same transactions the new block stores; otherwise the block
        # fails Blockchain.valid_chain().
        blockchain.submit_transaction(to_json_dict(global_parameters))
        last_block = blockchain.chain[-1]
        nonce = blockchain.proof_of_work()
        blockchain.create_block(nonce, blockchain.hash(last_block))

        # Evaluate the new global weights on the test set.
        with torch.no_grad():
            net.load_state_dict(global_parameters, strict=True)
            correct = 0
            total = 0
            for data, label in testDataLoader:
                data, label = data.to(dev), label.to(dev)
                preds = torch.argmax(net(data), dim=1)
                # Count samples rather than averaging per-batch means, which
                # over-weights a smaller final batch.
                correct += (preds == label).sum().item()
                total += label.numel()
            accuracy = correct / total
            print('accuracy: {}'.format(accuracy))
            # Record accuracy for this aggregation round.
            acc.append(accuracy)

        parameters = {}

        # Record end time and compute delay
        cpu_end_time = process_time()
        end_time = time()
        cpu_run_time = cpu_end_time - cpu_begin_time
        run_time = end_time - begin_time
        cpu_current_time = cpu_end_time - cpu_start_time
        current_time = end_time - start_time
        cpu_cost_time.append(cpu_run_time)
        cost_time.append(run_time)
        cpu_total_time.append(cpu_current_time)
        total_time.append(current_time)

# All record lists are appended together, so they share one length.
df = pd.DataFrame({
    "acc": acc,
    "cpu_cost_time": cpu_cost_time,
    "cpu_total_time": cpu_total_time,
    "cost_time": cost_time,
    "total_time": total_time,
})
# Optional CSV export (DataFrame.append was removed in pandas 2.0; use concat):
# df["num_worker"] = NUM_WORKER
# df["num_miner"] = NUM_MINER
# df["rounds"] = list(range(1, len(df) + 1))
# if os.path.exists("bfl-lr.csv"):
#     df = pd.concat([pd.read_csv("bfl-lr.csv"), df], ignore_index=True)
# df.to_csv("bfl-lr.csv", index=False)