In [2]:
import socket
import struct
import scipy.io
import numpy as np
from functools import partial
from concurrent.futures import ProcessPoolExecutor
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import DirichletPartitioner
from torchvision.transforms import ToTensor
from flwr_datasets.visualization import plot_label_distributions
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from cython_decoder import cython_sc_decoding

In [None]:
host = '127.0.0.1'
port = 5000
server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
server_socket.bind((host, port))
server_socket.listen(8)

array = np.array([1.0, 2.5, 3.5, 4.5, 5.5], dtype=np.float32)
shape = array.shape
shape_data = struct.pack('!' + 'I' * len(shape), *shape)
shape_size = len(shape_data)
array_data = array.tobytes()
array_length = len(array_data)
packet=struct.pack('!I', shape_size) + shape_data + struct.pack('!I', array_length) + array_data

In [None]:
node_s = []
node_r = []

try:
    while True:
        client_socket, addr = server_socket.accept()
        server_socket.settimeout(1)
        data = client_socket.recv(1024).decode()
        if data == "Server-R":
            server_s = client_socket
        elif data == "Server-S":
            server_r = client_socket
        elif data == "Node-R":
            node_s.append(client_socket)
        elif data == "Node-S":
            node_r.append(client_socket)
        client_socket.sendall(struct.pack('I',len(b"start"))+b"start")
except socket.timeout:
    print('Timeout')
    server_socket.settimeout(None)

for tmp_socket in node_r:
    tmp_socket.recv(1024)
server_r.recv(1024)

In [None]:
for _ in range(100):
    server_s.sendall(struct.pack('I',len(packet))+packet)

In [None]:
for tmp_socket in node_r:
    tmp_socket.close()
for tmp_socket in node_s:
    tmp_socket.close()
server_s.close()
server_r.close()

In [5]:
fds = FederatedDataset(
    dataset="cifar10",
    partitioners={
        "train": DirichletPartitioner(
            num_partitions=10,
            partition_by="label",
            alpha=0.1,
            seed=42,
            min_partition_size=0,
        ),
    },
)

In [6]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.Resize(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

def train_transforms(batch):
  transforms = transform_train
  batch["img"] = [transforms(img) for img in batch["img"]]
  return batch

def test_transforms(batch):
    transforms = transform_test
    batch["img"] = [transforms(img) for img in batch["img"]]
    return batch

partition = fds.load_partition(0, "train").with_transform(train_transforms)
centralized_dataset = fds.load_split("test").with_transform(test_transforms)
train_loader = DataLoader(partition, batch_size=512, shuffle=True, num_workers=16)
test_loader = DataLoader(centralized_dataset, batch_size=100, shuffle=False, num_workers=16)

In [7]:
from models.vit_small import ViT
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = ViT(
    image_size = 32,
    patch_size = 4,
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 512,
    dropout=0.1,
    emb_dropout=0.1
).to(device)


optimizer = optim.Adam(net.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 20)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [8]:
def train_model(model: nn.Module, 
                train_loader: DataLoader, 
                criterion: nn.Module, 
                device: torch.device, 
                scaler: torch.cuda.amp.GradScaler, 
                optimizer: torch.optim.Optimizer,
                epoch: int):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    for batch in train_loader:
        inputs = batch["img"].to(device)
        labels = batch["label"].to(device)
        with torch.cuda.amp.autocast(enabled=True):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

        total_loss += loss.item()
        total_samples += labels.size(0)
        _, preds = torch.max(outputs, 1)
        total_correct += (preds == labels).sum().item()
    print(f"Epoch: {epoch},Train Loss: {total_loss / total_samples:.4f}, Train Accuracy: {total_correct / total_samples:.4f}")

def evaluate_model(model: nn.Module, 
                   test_loader: DataLoader, 
                   criterion: nn.Module, 
                   device: torch.device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for batch in test_loader:
            inputs = batch["img"].to(device)
            labels = batch["label"].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            total_samples += labels.size(0)
            _, preds = torch.max(outputs, 1)
            total_correct += (preds == labels).sum().item()
    print(f"Validation Loss: {total_loss / total_samples:.4f}, Validation Accuracy: {total_correct / total_samples:.4f}\n\t")

In [9]:
for i in range(20):
    train_model(net, train_loader, criterion, device, scaler, optimizer, i)
    evaluate_model(net, test_loader, criterion, device)
    scheduler.step()


Epoch: 0,Train Loss: 0.0029, Train Accuracy: 0.5645
Validation Loss: 0.0534, Validation Accuracy: 0.1576
	
Epoch: 1,Train Loss: 0.0020, Train Accuracy: 0.6724
Validation Loss: 0.0527, Validation Accuracy: 0.1591
	
Epoch: 2,Train Loss: 0.0018, Train Accuracy: 0.7070
Validation Loss: 0.0492, Validation Accuracy: 0.1709
	
Epoch: 3,Train Loss: 0.0017, Train Accuracy: 0.7263
Validation Loss: 0.0471, Validation Accuracy: 0.1800
	
Epoch: 4,Train Loss: 0.0016, Train Accuracy: 0.7385
Validation Loss: 0.0459, Validation Accuracy: 0.2234
	
Epoch: 5,Train Loss: 0.0016, Train Accuracy: 0.7464
Validation Loss: 0.0453, Validation Accuracy: 0.2072
	
Epoch: 6,Train Loss: 0.0015, Train Accuracy: 0.7581
Validation Loss: 0.0453, Validation Accuracy: 0.2102
	
Epoch: 7,Train Loss: 0.0014, Train Accuracy: 0.7609
Validation Loss: 0.0427, Validation Accuracy: 0.2464
	
Epoch: 8,Train Loss: 0.0014, Train Accuracy: 0.7712
Validation Loss: 0.0438, Validation Accuracy: 0.2535
	
Epoch: 9,Train Loss: 0.0014, Train Ac

In [51]:
def int_to_bit(number):
    # Ensure the input number is a NumPy int32
    int32_number = np.array([number], dtype=np.int32)
    # Convert the int32 number to a byte buffer
    byte_representation = int32_number.tobytes()
    # Convert each byte to an 8-bit binary string and concatenate them
    bit_stream = ''.join(format(byte, '08b') for byte in byte_representation)
    # Convert the bit stream to a NumPy array of floats (as in your original bit_array)
    bit_array = np.array([float(bit) for bit in bit_stream], dtype=np.float64)
    return bit_array

def bit_to_int(bit_array):
    # Convert the bit array (floats) to a string of bits
    bit_string = ''.join(str(int(bit)) for bit in bit_array)
    # Convert the bit string to an integer
    int_value = int(bit_string, 2)  # Convert binary string to an integer
    # To get the original int32 value, interpret the bits as an int32 value
    # Handle cases where the original number might be negative by using np.int32
    int32_value = np.frombuffer(int_value.to_bytes(4, byteorder='big'), dtype=np.int32)[0]
    return int32_value

def tensor_to_bit(tensor):
    if tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    byte_tensor = tensor.view(torch.uint8).flatten()
    bit_stream = ''.join(format(byte.item(), '08b') for byte in byte_tensor)
    bit_tensor = torch.tensor([float(bit) for bit in bit_stream], dtype=torch.float64)
    return bit_tensor

def bit_to_tensor(bit_tensor):
    bit_stream = ''.join(str(int(bit.item())) for bit in bit_tensor)
    byte_tensor = torch.tensor([int(bit_stream[i:i+8], 2) for i in range(0, len(bit_stream), 8)], dtype=torch.uint8)
    float_tensor = byte_tensor.view(torch.float32)
    return float_tensor

def numpy_to_bit(array):
    tmp_byte=np.frombuffer(array.tobytes(),dtype=np.uint8)
    bit_stream = ''.join(format(byte, '08b') for byte in tmp_byte)
    bit_array = np.array([float(bit) for bit in bit_stream], dtype=np.float64)
    return bit_array

def bit_to_numpy(bit_array):
    bit_stream = ''.join(str(int(bit)) for bit in bit_array)
    byte_array_back = np.array([int(bit_stream[i:i+8], 2) for i in range(0, len(bit_stream), 8)], dtype=np.uint8)
    float_array_back = np.frombuffer(byte_array_back.tobytes(), dtype=np.float32)
    return float_array_back

'''
Encoding and decoding function
'''
def encode(u):
    N = u.shape[0]  # Get the length of u
    n = int(np.log2(N))  # Calculate the log base 2 of N

    if n == 1:
        x = np.array([(u[0] + u[1]) % 2, u[1]])
        return x
    else:
        x1 = encode(np.mod(u[:N//2] + u[N//2:], 2))
        x2 = encode(u[N//2:])
        x = np.concatenate((x1, x2))
        return x

def rvsl(y):
    N = y.shape[0]
    if N == 2:
        return y
    else:
        return np.concatenate((rvsl(y[0:N:2]), rvsl(y[1:N:2])))


def data_process(array):
    bit_array = numpy_to_bit(array)
    array_len = len(bit_array)
    for i in range(0, array_len, 512):
        sub_array = bit_array[i:i+512]
        if len(sub_array) < 512:
            padding = np.ones((512 - len(sub_array)), dtype=bit_array.dtype)
            sub_array = np.concatenate((sub_array, padding))
    current_idx = i // 512 + 1
    return sub_array, current_idx

def data_generate_torch(bit_tensor, data_idx):
    u = torch.zeros(1024, dtype=torch.float32)
    u[data_idx] = bit_tensor
    x = encode(u)
    x = rvsl(x)
    x = 1 - 2 * x
    return x

def data_generate(bit_array, data_idx):
    u=np.zeros(1024)
    u[data_idx] = bit_array
    x = encode(u)
    x = rvsl(x)
    x = 1-2*x
    return x

def decoding(bit_array, freeze_idx, data_idx):
    # Prepare the necessary arrays and values
    lr0 = np.exp(-(bit_array - 1)**2)
    lr1 = np.exp(-(bit_array + 1)**2)
    lr0_post = lr0 / (lr0 + lr1)
    lr1_post = lr1 / (lr0 + lr1)
    delete_num = 1024 - len(bit_array)
    hd_dec = np.zeros(1024, dtype=np.float64)
    frozen_val = np.zeros(len(freeze_idx), dtype=np.float64)
    pro_prun = np.zeros((1, 2 * 1024 + 1), dtype=np.float64)

    # Call the optimized Cython function
    i_scen_sum, hd_dec_result = cython_sc_decoding(
        lr0_post, lr1_post, freeze_idx.astype(np.float64),
        hd_dec, 1024, 10, 512, frozen_val, delete_num, 0, pro_prun
    )

    # Extract the output for data_idx from hd_dec_result
    data_out = hd_dec_result[data_idx]
    return data_out

In [47]:
# Export all weights to numpy arrays
weights_dict = {name: param.cpu().detach().numpy() for name, param in net.state_dict().items()}
N = 1024
n = 10
rate = 0.5
K = round(N*rate)
coding_list = scipy.io.loadmat("1024-3db-d=2-mean.mat")["count_number"]
coding_index = np.argsort(coding_list[:,1])
info_idx = coding_index[:K]
freeze_idx = coding_index[K:]

# sort the final index
info_ni = np.sort(info_idx)
freeze_ni = np.sort(freeze_idx)


In [None]:
split_bit = []
codeword_idx = []
bit_array_len = []
codeword_idx.append(0)
total_idx = 0

for tmp_key in weights_dict.keys():
    tmp_array = weights_dict[tmp_key]
    bit_array = numpy_to_bit(tmp_array)
    bit_array_len.append(len(bit_array))
    for i in range(0,len(bit_array),512):
        sub_array = bit_array[i:i+512]
        # Add 1 at the end of the array
        if len(sub_array) < 512:
            padding = np.ones((512 - len(sub_array)), dtype=bit_array.dtype)
            sub_array = np.concatenate((sub_array, padding))
        split_bit.append(sub_array)
    total_idx += i // 512 + 1
    codeword_idx.append(total_idx)


In [49]:
encode_partial = partial(data_generate, data_idx=info_ni)
with ProcessPoolExecutor() as executor:
    codeword_py = list(executor.map(encode_partial, split_bit))
