In [1]:
import numpy as np
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from transformers import (
    DistilBertForQuestionAnswering,
    DistilBertTokenizerFast
)
from datasets import load_dataset
import requests
import pickle
import base64
import gzip
import time

RANDOM_SEED      = 42
CLIENT_ID        = 2                         
SERVER_URL       = "http://bore.pub:64534"    

EPOCHS_PER_ROUND = 1
MAX_ROUNDS       = 1
BATCH_SIZE       = 8
LR               = 3e-5
MAX_LENGTH       = 384
DOC_STRIDE       = 128

torch.manual_seed(RANDOM_SEED + CLIENT_ID)
np.random.seed(RANDOM_SEED + CLIENT_ID)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")

print("Loading SQuAD dataset...")
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
dataset   = load_dataset('squad', split='train')

# Client 2 gets SECOND half
half    = len(dataset) // 2
dataset = dataset.select(range(half, len(dataset)))   # ← second half
print(f"Client {CLIENT_ID} — {len(dataset)} examples (second half)\n")

def preprocess(examples):
    inputs = tokenizer(
        examples['question'],
        examples['context'],
        max_length                = MAX_LENGTH,
        truncation                = 'only_second',
        stride                    = DOC_STRIDE,
        return_overflowing_tokens = True,
        return_offsets_mapping    = True,
        padding                   = 'max_length'
    )

    offset_mapping  = inputs.pop('offset_mapping')
    sample_map      = inputs.pop('overflow_to_sample_mapping')
    answers         = examples['answers']
    start_positions = []
    end_positions   = []

    for i, offset in enumerate(offset_mapping):
        sample_idx   = sample_map[i]
        answer       = answers[sample_idx]
        cls_index    = inputs['input_ids'][i].index(tokenizer.cls_token_id)
        sequence_ids = inputs.sequence_ids(i)

        if len(answer['answer_start']) == 0:
            start_positions.append(cls_index)
            end_positions.append(cls_index)
            continue

        start_char = answer['answer_start'][0]
        end_char   = start_char + len(answer['text'][0])

        token_start = token_end = cls_index
        for idx, (sid, (o_start, o_end)) in enumerate(zip(sequence_ids, offset)):
            if sid != 1:
                continue
            if o_start <= start_char < o_end:
                token_start = idx
            if o_start < end_char <= o_end:
                token_end = idx

        start_positions.append(token_start)
        end_positions.append(token_end)

    inputs['start_positions'] = start_positions
    inputs['end_positions']   = end_positions
    return inputs

print("Tokenizing dataset (this takes ~2 min)...")
tokenized = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset.column_names
)
tokenized.set_format('torch')
loader = DataLoader(tokenized, batch_size=BATCH_SIZE, shuffle=True)
print(f"{len(tokenized)} tokenized examples ready\n")

def quantize_weights(state_dict):
    quantized = {}
    for key, tensor in state_dict.items():
        w     = tensor.cpu().numpy().flatten().astype(np.float32)
        w_min = float(np.min(w))
        w_max = float(np.max(w))

        if abs(w_max - w_min) < 1e-8:
            scale = 1.0
            zp    = 0.0
            q     = np.zeros_like(w, dtype=np.int16)
        else:
            scale = (w_max - w_min) / 65535.0
            zp    = w_min
            q     = np.round((w - w_min) / scale).astype(np.int32) - 32768
            q     = np.clip(q, -32768, 32767).astype(np.int16)

        quantized[key] = {
            'quantized'  : base64.b64encode(q.tobytes()).decode('utf-8'),
            'scale'      : float(scale),
            'zero_point' : float(zp),
            'shape'      : list(tensor.shape),
        }
    return quantized

def get_size_mb(b):
    return len(b) / (1024 * 1024)


def get_global_weights():
    try:
        response = requests.get(f"{SERVER_URL}/get_weights", timeout=60)
        if response.status_code != 200:
            print(f"Server returned {response.status_code}")
            return None, None

        data       = response.json()
        compressed = base64.b64decode(data['weights'])
        pickled    = gzip.decompress(compressed)
        q_dict     = pickle.loads(pickled)

        state_dict = {}
        for key, item in q_dict.items():
            q_bytes = base64.b64decode(item['quantized'])
            q_array = np.frombuffer(q_bytes, dtype=np.int16).copy()
            dq      = (q_array.astype(np.float32) + 32768) * item['scale'] + item['zero_point']
            state_dict[key] = torch.tensor(dq.reshape(item['shape']))

        return state_dict, data['round']

    except Exception as e:
        print(f"Error getting weights: {e}")
        return None, None

def submit_weights(state_dict, round_num):
    try:
        original_size = get_size_mb(pickle.dumps(
            {k: v.cpu().numpy() for k, v in state_dict.items()}
        ))

        print("      Quantizing..")
        q_dict     = quantize_weights(state_dict)
        pickled    = pickle.dumps(q_dict)
        compressed = gzip.compress(pickled, compresslevel=6)
        encoded    = base64.b64encode(compressed).decode('utf-8')
        comp_size  = get_size_mb(compressed)

        print(f"   {original_size:.1f} MB → {comp_size:.1f} MB "
              f"({original_size/comp_size:.1f}x smaller)")

        payload = {
            'client_id'         : CLIENT_ID,
            'weights'           : encoded,
            'round'             : round_num,
            'original_size_mb'  : original_size,
            'compressed_size_mb': comp_size,
        }

        response = requests.post(
            f"{SERVER_URL}/submit_weights",
            json=payload,
            timeout=120
        )

        if response.status_code != 200:
            print(f"Server returned {response.status_code}: {response.text[:200]}")
            return None
        return response.json()

    except Exception as e:
        print(f"Error submitting weights: {e}")
        import traceback; traceback.print_exc()
        return None

model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
model.to(device)

print(f"  Federated QA Training — Client {CLIENT_ID}")
print(f"   Server : {SERVER_URL}")
print(f"   Rounds : {MAX_ROUNDS}  |  Epochs/round: {EPOCHS_PER_ROUND}\n")

for round_num in range(MAX_ROUNDS):
    print(f"\n{'='*60}")
    print(f"  ROUND {round_num+1}/{MAX_ROUNDS} — Client {CLIENT_ID}")
    print(f"{'='*60}")

    print("Downloading global model...")
    state_dict, server_round = get_global_weights()
    if state_dict is None:
        print("Failed — retrying in 10s...")
        time.sleep(10)
        continue
    model.load_state_dict(state_dict)
    print("Global model loaded")

    print(f"Training ({EPOCHS_PER_ROUND} epochs)...")
    model.train()
    optimizer = AdamW(model.parameters(), lr=LR)

    for epoch in range(EPOCHS_PER_ROUND):
        total_loss = 0
        for step, batch in enumerate(loader):
            batch    = {k: v.to(device) for k, v in batch.items()}
            outputs  = model(**batch)
            loss     = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            total_loss += loss.item()

            if step % 50 == 0:
                print(f"   Epoch {epoch+1} | Step {step}/{len(loader)} "
                      f"| Loss: {loss.item():.4f}")

        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch+1} — Avg Loss: {avg_loss:.4f}")

    print("Uploading weights...")
    result = submit_weights(model.state_dict(), round_num)

    if result and result.get('status') == 'success':
        print("Weights submitted successfully")
        if result.get('aggregating'):
            print("Server aggregating...")
    else:
        print("Failed to submit weights")

    time.sleep(5)

print(f"\n Client {CLIENT_ID} training complete!")

2026-02-13 14:46:34.207311: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1770993994.639053      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1770993994.767529      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1770993995.681350      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770993995.681397      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1770993995.681400      55 computation_placer.cc:177] computation placer alr

Device: cuda
Loading SQuAD dataset...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

plain_text/train-00000-of-00001.parquet:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

plain_text/validation-00000-of-00001.par(…):   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/87599 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/10570 [00:00<?, ? examples/s]

Client 2 — 43800 examples (second half)

Tokenizing dataset (this takes ~2 min)...


Map:   0%|          | 0/43800 [00:00<?, ? examples/s]

44319 tokenized examples ready



model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  Federated QA Training — Client 2
   Server : http://bore.pub:64534
   Rounds : 1  |  Epochs/round: 1


  ROUND 1/1 — Client 2
Downloading global model...
Global model loaded
Training (1 epochs)...
   Epoch 1 | Step 0/5540 | Loss: 5.9401
   Epoch 1 | Step 50/5540 | Loss: 3.6917
   Epoch 1 | Step 100/5540 | Loss: 3.7409
   Epoch 1 | Step 150/5540 | Loss: 4.0092
   Epoch 1 | Step 200/5540 | Loss: 2.9883
   Epoch 1 | Step 250/5540 | Loss: 1.8115
   Epoch 1 | Step 300/5540 | Loss: 2.8549
   Epoch 1 | Step 350/5540 | Loss: 2.3659
   Epoch 1 | Step 400/5540 | Loss: 1.4138
   Epoch 1 | Step 450/5540 | Loss: 3.0196
   Epoch 1 | Step 500/5540 | Loss: 2.8293
   Epoch 1 | Step 550/5540 | Loss: 3.7957
   Epoch 1 | Step 600/5540 | Loss: 1.4699
   Epoch 1 | Step 650/5540 | Loss: 3.1806
   Epoch 1 | Step 700/5540 | Loss: 2.6292
   Epoch 1 | Step 750/5540 | Loss: 1.8651
   Epoch 1 | Step 800/5540 | Loss: 2.0066
   Epoch 1 | Step 850/5540 | Loss: 2.0473
   Epoch 1 | Step 900/5540 | Loss: 2.7684
   Epo