In [None]:
import socket
import struct
import time
import scipy.io
import numpy as np
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from torchvision.transforms import ToTensor
from flwr_datasets.visualization import plot_label_distributions
from numba import njit, jit
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from cython_decoder import cython_sc_decoding

In [2]:
host = '127.0.0.1'
port = 5000
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow re-binding immediately after a kernel restart; without this the port can
# linger in TIME_WAIT and bind() fails with "Address already in use".
server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
server_socket.bind((host, port))
num_nodes = 5

In [None]:
num_nodes = 20
server_socket.listen((num_nodes+1)*2)
node_s = []
node_r = []

try:
    while True:
        client_socket, addr = server_socket.accept()
        # The first accept() blocks indefinitely; after that, a 10 s gap with no new
        # peer ends the loop (socket.timeout below).
        server_socket.settimeout(10)
        # NOTE(review): assumes the peer's role string arrives in a single recv().
        data = client_socket.recv(1024).decode()
        # Each peer announces its own role, so the remote's receiving end ("-R") is
        # the socket we send on, and its sending end ("-S") is the one we read from.
        if data == "Server-R":
            server_s = client_socket
        elif data == "Server-S":
            server_r = client_socket
        elif data == "Node-R":
            node_s.append(client_socket)
        elif data == "Node-S":
            node_r.append(client_socket)
        else:
            print(f'Unrecognized handshake from {addr}: {data!r}')
        # Length-prefixed "start" acknowledgement.
        client_socket.sendall(struct.pack('I',len(b"start"))+b"start")
except socket.timeout:
    print('Timeout')
finally:
    # Always return the listener to blocking mode, even if the loop is interrupted.
    server_socket.settimeout(None)

for tmp_socket in node_r:
    tmp_socket.recv(1024)
server_r.recv(65536)

In [None]:
# Tear down every connection opened by the handshake cell:
# node links first (receivers, then senders), then the two server links.
for tmp_socket in node_r + node_s:
    tmp_socket.close()
server_s.close()
server_r.close()

In [3]:
# Partition CIFAR-10's train split across `num_nodes` clients by label with a
# Dirichlet partitioner (alpha=0.1, fixed seed; empty partitions allowed).
fds = FederatedDataset(
    dataset="cifar10",
    partitioners=dict(
        train=DirichletPartitioner(
            num_partitions=num_nodes,
            partition_by="label",
            alpha=0.1,
            min_partition_size=0,
            seed=42,
        )
    ),
)

In [4]:
# Training pipeline: random 32x32 crop with 4-px padding, horizontal flip, then
# tensor conversion and per-channel normalization.
# NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied CIFAR-10 "std"; the
# per-pixel std of the training set is usually quoted as ~(0.247, 0.243, 0.261).
# Fine for training, but confirm before comparing with other pipelines.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(32),  # effectively a no-op after a 32x32 crop
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Evaluation pipeline: deterministic (no augmentation), same normalization.
transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

def train_transforms(batch):
    """Apply `transform_train` to every image of a batch dict (for `with_transform`).

    Overwrites ``batch["img"]`` with the list of transformed images and returns
    the same dict object.
    """
    augment = transform_train
    batch["img"] = list(map(augment, batch["img"]))
    return batch

def test_transforms(batch):
    """Apply the deterministic `transform_test` pipeline to every image of a batch dict.

    Overwrites ``batch["img"]`` with the transformed images and returns the same dict.
    """
    processed = []
    for image in batch["img"]:
        processed.append(transform_test(image))
    batch["img"] = processed
    return batch

# One train/validation DataLoader pair per node, built from that node's partition.
train_loader = []
test_loader = []
for i in range(num_nodes):
    # Seeded split: unseeded, every kernel restart re-draws each node's 10% validation
    # subset, making accuracy numbers incomparable across runs.
    # NOTE(review): min_partition_size=0 allows tiny partitions that cannot be split —
    # confirm every partition has enough samples.
    partition_train_test = fds.load_partition(i, "train").train_test_split(test_size=0.1, seed=42)
    partition_train = partition_train_test["train"].with_transform(train_transforms)
    partition_test = partition_train_test["test"].with_transform(test_transforms)
    train_loader.append(DataLoader(partition_train, batch_size=512, shuffle=True, num_workers=16))
    test_loader.append(DataLoader(partition_test, batch_size=100, shuffle=False, num_workers=16))

In [5]:
from models.vit_small import ViT
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def _make_vit():
    """Build one ViT with the architecture shared by every node and the server.

    All models must be identical because weights are exchanged via state_dict
    (load_state_dict on another node's weights); a single factory keeps them in sync.
    """
    return ViT(
        image_size=32,
        patch_size=4,
        num_classes=10,
        dim=32,
        depth=6,
        heads=8,
        mlp_dim=32,
        dropout=0.1,
        emb_dropout=0.1,
    ).to(device)


# Per-node model, optimizer, 20-step cosine schedule, loss and AMP scaler (same index = same node).
net = []
optimizer = []
scheduler = []
criterion = []
scaler = []
for i in range(num_nodes):
    net.append(_make_vit())
    optimizer.append(optim.Adam(net[i].parameters(), lr=0.001))
    scheduler.append(torch.optim.lr_scheduler.CosineAnnealingLR(optimizer[i], 20))
    criterion.append(nn.CrossEntropyLoss())
    scaler.append(torch.cuda.amp.GradScaler(enabled=True))

server_net = _make_vit()

In [6]:
def train_model(model: nn.Module, 
                train_loader: DataLoader, 
                criterion: nn.Module, 
                device: torch.device, 
                scaler: torch.cuda.amp.GradScaler, 
                optimizer: torch.optim.Optimizer,
                epoch: int,
                nodes: "int | None" = None):
    """Train `model` for one epoch with AMP autocast and gradient scaling.

    Args:
        model: network to train (switched to train mode).
        train_loader: iterable of dicts with "img" and "label" tensors.
        criterion: loss function; assumed to return the batch *mean* (the default
            reduction of nn.CrossEntropyLoss used in this notebook).
        device: device the batches are moved to.
        scaler: gradient scaler wrapping backward/step.
        optimizer: optimizer stepped once per batch.
        epoch: epoch number, used only in the progress line.
        nodes: node id, used only in the progress line (optional).

    Returns:
        (mean per-sample loss, accuracy) over the epoch; (0.0, 0.0) for an empty loader.
    """
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batch in train_loader:
        inputs = batch["img"].to(device)
        labels = batch["label"].to(device)
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        batch_size = labels.size(0)
        # The loss is a batch mean: weight it by batch size so the epoch figure is a
        # true per-sample mean (and stays correct for a short last batch).
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        _, preds = torch.max(outputs, 1)
        total_correct += (preds == labels).sum().item()
    denom = max(total_samples, 1)  # empty partitions are possible (min_partition_size=0)
    avg_loss = total_loss / denom
    accuracy = total_correct / denom
    print(f"Nodes: {nodes}, Epoch: {epoch},Train Loss: {avg_loss:.4f}, Train Accuracy: {accuracy:.4f}")
    return avg_loss, accuracy

def evaluate_model(model: nn.Module, 
                   test_loader: DataLoader, 
                   criterion: nn.Module, 
                   device: torch.device):
    """Evaluate `model` without gradients and print validation loss/accuracy.

    Args:
        model: network to evaluate (switched to eval mode).
        test_loader: iterable of dicts with "img" and "label" tensors.
        criterion: loss function; assumed to return the batch *mean*.
        device: device the batches are moved to.

    Returns:
        (mean per-sample loss, accuracy); (0.0, 0.0) for an empty loader.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch["img"].to(device)
            labels = batch["label"].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Re-weight the batch-mean loss so the total is a per-sample sum.
            total_loss += loss.item() * batch_size
            total_samples += batch_size
            _, preds = torch.max(outputs, 1)
            total_correct += (preds == labels).sum().item()
    denom = max(total_samples, 1)  # guard against empty validation splits
    avg_loss = total_loss / denom
    accuracy = total_correct / denom
    print(f"Validation Loss: {avg_loss:.4f}, Validation Accuracy: {accuracy:.4f}\n\t")
    return avg_loss, accuracy

In [None]:
# Local pre-training: every node trains 20 epochs on its own partition.
for cli in range(num_nodes):
    for i in range(20):
        train_model(net[cli], train_loader[cli], criterion[cli], device, scaler[cli], optimizer[cli], i, cli)
        evaluate_model(net[cli], test_loader[cli], criterion[cli], device)
        scheduler[cli].step()
    # A completed 20-step cosine cycle (eta_min=0) leaves the optimizer lr at exactly 0,
    # and a new CosineAnnealingLR starts from the optimizer's *current* lr — so later
    # phases would train with lr 0. Restore the base lr before recreating the schedule.
    for group in optimizer[cli].param_groups:
        group['lr'] = group.get('initial_lr', group['lr'])
    scheduler[cli] = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer[cli], 20)
    

In [None]:
def numpy_to_bit(array: np.ndarray) -> np.ndarray:
    """Expand an array's raw bytes (C order) into a flat int8 vector of 0/1 bits.

    Bits are emitted most-significant first within each byte, so a uint8 value
    0x81 becomes [1, 0, 0, 0, 0, 0, 0, 1]. An empty array gives an empty vector.
    """
    raw = np.frombuffer(array.tobytes(), dtype=np.uint8)
    # unpackbits' default bitorder='big' is MSB-first, matching format(byte, '08b').
    return np.unpackbits(raw).astype(np.int8)

def bit_to_numpy(bit_array: np.ndarray, dtype=np.float32) -> np.ndarray:
    """Inverse of numpy_to_bit: pack MSB-first 0/1 bits back into a flat `dtype` array.

    Args:
        bit_array: 1-D sequence of 0/1 values, MSB first within each byte.
        dtype: element type to reinterpret the bytes as (default float32, the only
            type this notebook round-trips).

    Returns:
        A read-only 1-D array of `dtype` viewing the packed bytes.

    Raises:
        ValueError: if the bit count is not a whole number of `dtype` items.
    """
    bits = np.asarray(bit_array)
    item_bits = 8 * np.dtype(dtype).itemsize
    if bits.size % item_bits:
        raise ValueError(
            f"bit length {bits.size} is not a multiple of {item_bits} (one {np.dtype(dtype)} item)"
        )
    packed = np.packbits(bits.astype(np.uint8))
    return np.frombuffer(packed.tobytes(), dtype=dtype)

'''
Encoding and decoding function
'''
# Polar transform (numba-compiled). Plain comments instead of docstrings inside the
# jitted functions keep them minimal for nopython compilation.
@jit(nopython=True)
def encode(u: np.ndarray) -> np.ndarray:
    # Recursive polar transform of a 0/1 vector over GF(2):
    #   encode(u) = [encode((u_a + u_b) mod 2), encode(u_b)]
    # where u_a / u_b are the first / second halves of u. Returns an int8 array
    # of the same length.
    # NOTE(review): N must be a power of two >= 2 — other lengths either drop
    # entries at the n == 1 base case (e.g. N == 3) or fail.
    N = u.shape[0]  # Get the length of u
    n = int(np.log2(N))  # Calculate the log base 2 of N

    if n == 1:
        # Base case N == 2: [u0 XOR u1, u1]
        x = np.array([(u[0] + u[1]) % 2, u[1]],dtype=np.int8)
        return x
    else:
        # Combine the halves, then transform each half independently.
        x1 = encode(np.mod(u[:N//2] + u[N//2:], 2))
        x2 = encode(u[N//2:])
        x = np.concatenate((x1, x2))
        return x

@jit(nopython=True)
def rvsl(y: np.ndarray) -> np.ndarray:
    # Bit-reversal permutation: recursively place the even-indexed entries before
    # the odd-indexed ones (N == 8 -> index order 0,4,2,6,1,5,3,7).
    # Expects a power-of-two length >= 2. The N == 2 base case returns `y` itself
    # (no copy); larger inputs return a new array.
    N = y.shape[0]
    if N == 2:
        return y
    else:
        return np.concatenate((rvsl(y[0:N:2]), rvsl(y[1:N:2])))

def data_process(array, block_size: int = 512):
    """Split an array's bit representation into fixed-size message blocks.

    Args:
        array: numpy array to serialize (via numpy_to_bit).
        block_size: bits per block (default 512, the number of information
            positions per codeword in this notebook).

    Returns:
        (blocks, n_blocks, array_len): list of 1-D int8 blocks of `block_size` bits,
        the number of blocks, and the unpadded bit length. The last block is padded
        with ones. An empty array yields ([], 0, 0) instead of crashing.
    """
    bit_array = numpy_to_bit(array)
    array_len = len(bit_array)
    n_blocks = -(-array_len // block_size)  # ceiling division
    pad = n_blocks * block_size - array_len
    if pad:
        bit_array = np.concatenate((bit_array, np.ones(pad, dtype=bit_array.dtype)))
    current_array = list(bit_array.reshape(n_blocks, block_size))
    return current_array, n_blocks, array_len

def numpy_array_to_udp_packet(bit_array):
    """Pack a sequence of 0/1 bits into a bytearray, MSB first within each byte.

    Full bytes are packed with np.packbits. A trailing partial byte (fewer than 8
    bits) keeps this function's historical right-aligned encoding: bits "101"
    become the byte value 5, not 0b10100000.
    """
    bits = np.asarray(bit_array).astype(np.uint8).ravel()
    n_full = bits.size - bits.size % 8
    packet = bytearray(np.packbits(bits[:n_full]).tobytes())
    tail = bits[n_full:]
    if tail.size:
        packet.append(int(''.join(str(int(b)) for b in tail), 2))
    return packet

def udp_packet_to_numpy_array(packet):
    """Unpack packet bytes into a flat int8 array of 0/1 bits, MSB first per byte.

    `packet` may be bytes, a bytearray, or any iterable of 0..255 ints.
    """
    raw = np.frombuffer(bytes(packet), dtype=np.uint8)
    return np.unpackbits(raw).astype(np.int8)

def data_generate(bit_array: np.ndarray, data_idx: np.ndarray) -> np.ndarray:
    """Encode one message block into a 1024-position codeword.

    The message bits are placed at `data_idx`; every other position is frozen to 0.
    The vector is polar-transformed (encode) and bit-reverse permuted (rvsl).
    """
    message = np.zeros(1024, dtype=np.int8)
    message[data_idx] = bit_array
    return rvsl(encode(message))

def codeword_generate(array_dict, data_idx):
    """Serialize every array in `array_dict` and encode it into polar codewords.

    Arrays are bit-split into blocks (data_process) in a process pool, then every
    block is encoded (data_generate) in a second pool. Only the encoding stage is
    timed and printed.

    Returns:
        (codewords, codeword_idx, bit_array_len):
        codewords     -- int8 array, one codeword per block, in dict iteration order;
        codeword_idx  -- prefix offsets; array k owns codewords codeword_idx[k]:codeword_idx[k+1];
        bit_array_len -- unpadded bit length of each array.
    """
    split_bit = []
    codeword_idx = [0]
    bit_array_len = []
    with ProcessPoolExecutor() as pool:
        for blocks, n_blocks, n_bits in pool.map(data_process, list(array_dict.values())):
            split_bit += blocks
            codeword_idx.append(codeword_idx[-1] + n_blocks)
            bit_array_len.append(n_bits)

    t_start = time.time()
    with ProcessPoolExecutor() as pool:
        codeword = list(pool.map(partial(data_generate, data_idx=data_idx), split_bit))
    t_end = time.time()
    print(f'Encode time: {t_end-t_start}')

    return np.array(codeword, dtype=np.int8), codeword_idx, bit_array_len

def packet_diffusion(codeword, block_len, packet_idx):
    """Interleave codewords into packets.

    Packet k carries, for every codeword (row of `codeword`), the bits at positions
    packet_idx[k*block_len:(k+1)*block_len], row after row. It is prefixed with k as a
    4-byte native-order unsigned int, so losing one packet erases only `block_len`
    positions per codeword.
    """
    n_positions = codeword[0].shape[0]
    return [
        struct.pack("I", packet_no)
        + numpy_array_to_udp_packet(codeword[:, packet_idx[start:start + block_len]].flatten())
        for packet_no, start in enumerate(range(0, n_positions, block_len))
    ]

def encoder_udp(array_dict, data_idx, block_len, packet_idx):
    """Full sender pipeline: arrays -> polar codewords -> interleaved packets.

    Returns (packets, codeword_idx, bit_array_len); the last two are needed by
    packet_aggregation on the receiving side.
    """
    codewords, offsets, bit_lengths = codeword_generate(array_dict, data_idx)
    return packet_diffusion(codewords, block_len, packet_idx), offsets, bit_lengths



def packet_aggregation(udp_packet, packet_idx, block_len, data_idx, freeze_idx, codeword_idx, bit_array_len):
    """Rebuild and decode the arrays sent by encoder_udp from whatever packets arrived.

    Args:
        udp_packet: received packets, each a 4-byte native uint32 packet index followed
            by packed bits. Any order is accepted; duplicates overwrite each other.
        packet_idx: the codeword-position permutation used by packet_diffusion.
        block_len: bits per codeword carried in each packet.
        data_idx, freeze_idx: information / frozen positions passed to decoding().
        codeword_idx: prefix offsets from codeword_generate (array k owns codewords
            codeword_idx[k]:codeword_idx[k+1]).
        bit_array_len: unpadded bit length of each array.

    Returns:
        List of flat arrays decoded by bit_to_numpy, one per original array.
    """
    num_packets = int(1024 / block_len)  # packets per 1024-position codeword
    # Sort by header index so rows are filled deterministically.
    received = sorted((struct.unpack("I", pkt[:4])[0], pkt) for pkt in udp_packet)
    rows = [udp_packet_to_numpy_array(pkt[4:]) for _, pkt in received]
    # Bits per packet: taken from the packets themselves. If none arrived, fall back to
    # block_len bits per codeword so everything is decoded as erasures instead of crashing.
    row_len = len(rows[0]) if rows else codeword_idx[-1] * block_len
    # Lost packets stay at 0.5; decoding() maps 0.5 to equal posteriors for 0 and 1.
    packet_data = np.full((num_packets, row_len), 0.5)
    for (pkt_no, _), row in zip(received, rows):
        packet_data[pkt_no] = row

    # Column block j holds codeword j's bits in packet_idx order; argsort inverts it.
    inverse_packet_idx = np.argsort(packet_idx)
    restore_codeword = [
        packet_data[:, start:start + block_len].flatten()[inverse_packet_idx]
        for start in range(0, packet_data.shape[1], block_len)
    ]

    decode_partial = partial(decoding, freeze_idx=freeze_idx, data_idx=data_idx)
    with ProcessPoolExecutor() as executor:
        decoding_data = np.array(list(executor.map(decode_partial, restore_codeword)), dtype=np.int8)

    restore_array = []
    for i, array_len in enumerate(bit_array_len):
        # reshape(-1) rather than np.concatenate: a zero-size tensor owns no codewords,
        # and concatenating an empty selection raises ValueError.
        tmp_array = decoding_data[codeword_idx[i]:codeword_idx[i + 1]].reshape(-1)[:array_len]
        restore_array.append(bit_to_numpy(tmp_array))
    return restore_array


def decoding(bit_array, freeze_idx, data_idx):
    """Soft-decode one received codeword with the Cython SC decoder.

    Args:
        bit_array: numpy array of received values in codeword-position order
            (0/1 for received bits, 0.5 for erasures as produced by packet_aggregation).
        freeze_idx: frozen positions (passed to the decoder as float64).
        data_idx: positions whose hard decisions are returned.

    Returns:
        The decoder's hard decisions at `data_idx`.
    """
    # Map to antipodal symbols: bit 0 -> +1, bit 1 -> -1, erasure 0.5 -> 0.
    symbols = 1 - 2 * bit_array
    like0 = np.exp(-(symbols - 1) ** 2)
    like1 = np.exp(-(symbols + 1) ** 2)
    norm = like0 + like1
    post0 = like0 / norm
    post1 = like1 / norm

    # NOTE(review): 1024 / 10 / 512 match N, n and K of the config cell; the Cython
    # signature is not visible here — confirm their meaning before changing them.
    frozen_positions = freeze_idx.astype(np.float64)
    hard_decisions = np.zeros(1024, dtype=np.float64)
    frozen_values = np.zeros(len(freeze_idx), dtype=np.float64)
    pruning = np.zeros((1, 2 * 1024 + 1), dtype=np.float64)
    shortened = 1024 - len(symbols)

    _, decided = cython_sc_decoding(
        post0, post1, frozen_positions,
        hard_decisions, 1024, 10, 512, frozen_values, shortened, 0, pruning
    )
    return decided[data_idx]


In [None]:
# Export all weights to numpy arrays
# NOTE(review): bit_to_numpy() reinterprets every decoded tensor as float32 — confirm
# net[0]'s state_dict holds only float32 tensors, or the round trip will not be exact.
weights_dict = {name: param.cpu().detach().numpy() for name, param in net[0].state_dict().items()}
# Code parameters: block length N = 2**n with K = N*rate information positions.
N = 1024
n = 10  # log2(N); not read again in this cell
rate = 0.5
K = round(N*rate)
# Position permutation used as `packet_idx` for packet interleaving.
c_1024 = np.load('c_1024.npy')
# Per-position statistics; positions are ranked by column 1, ascending.
# NOTE(review): presumably a lower column-1 value means a more reliable position — confirm.
coding_list = scipy.io.loadmat("1024-3db-d=2-mean.mat")["count_number"]
coding_index = np.argsort(coding_list[:,1])
info_idx = coding_index[:K]    # first K ranked positions carry data
freeze_idx = coding_index[K:]  # the rest are frozen (left 0 by data_generate)

# sort the final index: ascending position lists for the encoder/decoder
info_ni = np.sort(info_idx)
freeze_ni = np.sort(freeze_idx)

# block_len=8: each packet carries 8 positions of every codeword.
udp_packet, codeword_idx, bit_array_len = encoder_udp(weights_dict, info_ni, 8, c_1024)

In [None]:
# Packets per transmission: one per block_len-wide slice of the 1024-position codewords.
len(udp_packet)

In [None]:
recv_packet = {i: [] for i in range(num_nodes)}
# TCP is a byte stream: recv() may return part of a message or several at once, so
# unparsed bytes are kept here between drain rounds.
recv_buffer = bytearray()


def _drain_server_r(msg_len, first_timeout=3.0, idle_timeout=0.5):
    """Read fixed-size messages from `server_r` until it goes quiet; file them by node id.

    Each message is assumed to be `msg_len` bytes: a 4-byte node id followed by the
    packet. This is the same framing the original recv(len(tmp_packet)) assumed.
    Waits up to `first_timeout` s for the first chunk, then `idle_timeout` s between
    chunks. A timeout ends the round; a partial message stays in `recv_buffer`.
    """
    server_r.settimeout(first_timeout)
    try:
        while True:
            chunk = server_r.recv(65536)
            if not chunk:
                print('server_r closed by peer')
                break
            recv_buffer.extend(chunk)
            server_r.settimeout(idle_timeout)
            while len(recv_buffer) >= msg_len:
                message = bytes(recv_buffer[:msg_len])
                del recv_buffer[:msg_len]
                node_id = struct.unpack('I', message[:4])[0]
                recv_packet.setdefault(node_id, []).append(message[4:])
    except socket.timeout:
        pass
    finally:
        server_r.settimeout(None)


for tmp_id in range(num_nodes):
    # NOTE(review): tmp_id is unused — every round re-sends the whole model to node_s[0]
    # tagged with id 0. Confirm this is intended (vs node_s[tmp_id] / tag tmp_id).
    for i in range(len(udp_packet)):
        tmp_packet = struct.pack('I', 0) + udp_packet[i]
        # Outgoing frames are length-prefixed.
        node_s[0].sendall(struct.pack('I', len(tmp_packet)) + tmp_packet)
        if (i + 1) % 16 == 0:
            _drain_server_r(len(tmp_packet))
    if len(udp_packet) % 16:
        # Also collect replies for a final partial group of packets.
        _drain_server_r(len(tmp_packet))

In [None]:
# Reassemble and decode the packets that came back tagged with node id 0.
restored_array = packet_aggregation(
    recv_packet[0],
    packet_idx=c_1024,
    block_len=8,
    data_idx=info_ni,
    freeze_idx=freeze_ni,
    codeword_idx=codeword_idx,
    bit_array_len=bit_array_len,
)

In [None]:
# Rebuild a state_dict: reshape each decoded flat array to its original tensor shape
# (matched by position in weights_dict) and move it to `device`.
restored_dict = {
    name: torch.tensor(restored_array[i].reshape(weights_dict[name].shape)).to(device)
    for i, name in enumerate(weights_dict)
}

In [None]:
# Fine-tune node 1's model on node 1's data, starting from the received weights.
net[1].load_state_dict(restored_dict)
# Restart the cosine schedule from the base lr: a completed 20-step cycle leaves lr at 0,
# and a schedule that is not reset would keep it there (also keeps this cell re-runnable).
for group in optimizer[1].param_groups:
    group['lr'] = group.get('initial_lr', group['lr'])
scheduler[1] = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer[1], 20)
for i in range(20):
    # train_model requires `nodes` unless it has a default; pass it explicitly.
    train_model(net[1], train_loader[1], criterion[1], device, scaler[1], optimizer[1], i, 1)
    evaluate_model(net[1], test_loader[1], criterion[1], device)
    scheduler[1].step()

In [None]:
# Fine-tune node 2's model on node 0's partition, starting from the received weights.
# NOTE(review): uses train_loader[0]/test_loader[0] with node 2's model, optimizer and
# scaler — confirm the cross-node data pairing is intended.
net[2].load_state_dict(restored_dict)
# Restart the cosine schedule from the base lr (a completed 20-step cycle leaves lr at 0).
for group in optimizer[2].param_groups:
    group['lr'] = group.get('initial_lr', group['lr'])
scheduler[2] = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer[2], 20)
for i in range(20):
    train_model(net[2], train_loader[0], criterion[2], device, scaler[2], optimizer[2], i, 2)
    evaluate_model(net[2], test_loader[0], criterion[2], device)
    scheduler[2].step()

In [None]:
# Bit-exact round-trip check: every decoded tensor must equal the exported original.
assert len(restored_array) == len(weights_dict), (
    f"decoded {len(restored_array)} tensors, expected {len(weights_dict)}"
)
for (name, original), restored in zip(weights_dict.items(), restored_array):
    assert np.array_equal(restored, original.flatten()), f"decoded weights differ for {name}"

In [None]:
# Spot-check the second exported tensor, flattened. This displays the whole array —
# slice it (e.g. [:10]) if the tensor is large.
list(weights_dict.values())[1].flatten()