<a href="https://colab.research.google.com/github/Henil21/federated-learning-distilbert-Quantized/blob/main/federated_lr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install flask flask-cors
!wget -q https://github.com/ekzhang/bore/releases/download/v0.5.0/bore-v0.5.0-x86_64-unknown-linux-musl.tar.gz
!tar -xzf bore-v0.5.0-x86_64-unknown-linux-musl.tar.gz
!chmod +x bore
!mv bore /usr/local/bin/





##Fedrated Learning fine tuning

In [2]:
!pip install transformers datasets torch -q

In [3]:
import numpy as np
import torch
from flask import Flask, request, jsonify
import pickle
import base64
import threading
import subprocess
import time
import os
import gzip
RANDOM_SEED = 42
torch.manual_seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

num_clients    = 2
current_round  = 0
max_rounds     = 1
client_weights = {}

app = Flask(__name__)
app.config['MAX_CONTENT_LENGTH'] = 512 * 1024 * 1024  # 512 MB

def load_distilbert():
    from transformers import DistilBertForQuestionAnswering
    model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')
    model.eval()
    return model

print(" Initializing DistilBERT for Question Answering...")
global_model = load_distilbert()
print(f" Model ready — {sum(p.numel() for p in global_model.parameters()):,} parameters")

def get_model_weights():
    """Extract state dict as numpy arrays — no HuggingFace classes involved"""
    return {k: v.cpu().detach().numpy() for k, v in global_model.state_dict().items()}

def set_model_weights(weights_dict):
    """Load numpy arrays back into model"""
    state_dict = {k: torch.tensor(v) for k, v in weights_dict.items()}
    global_model.load_state_dict(state_dict, strict=True)

global_weights = get_model_weights()
print("Global weights extracted!")

def quantize_weights(weights_dict):
    """Quantize float32 weight dict → INT16"""
    quantized = {}
    for key, w in weights_dict.items():
        w_flat = w.flatten().astype(np.float32)
        w_min  = float(np.min(w_flat))
        w_max  = float(np.max(w_flat))

        if abs(w_max - w_min) < 1e-8:
            scale = 1.0
            zp    = 0.0
            q     = np.zeros_like(w_flat, dtype=np.int16)
        else:
            scale = (w_max - w_min) / 65535.0
            zp    = w_min
            q     = np.round((w_flat - w_min) / scale).astype(np.int32) - 32768
            q     = np.clip(q, -32768, 32767).astype(np.int16)

        quantized[key] = {
            'quantized'  : base64.b64encode(q.tobytes()).decode('utf-8'),
            'scale'      : float(scale),
            'zero_point' : float(zp),
            'shape'      : list(w.shape),
        }
    return quantized

def dequantize_weights(quantized_dict):
    """Dequantize INT16 weight dict → float32 numpy arrays"""
    weights = {}
    for key, item in quantized_dict.items():
        q_bytes  = base64.b64decode(item['quantized'])
        q_array  = np.frombuffer(q_bytes, dtype=np.int16).copy()
        scale    = float(item['scale'])
        zp       = float(item['zero_point'])
        shape    = tuple(item['shape'])
        dq       = (q_array.astype(np.float32) + 32768) * scale + zp
        weights[key] = dq.reshape(shape)
    return weights

def aggregate_weights():
    global global_weights, client_weights

    n        = len(client_weights)
    averaged = {k: np.zeros_like(v) for k, v in global_weights.items()}

    for weights in client_weights.values():
        for k in averaged:
            averaged[k] += weights[k]

    global_weights = {k: v / n for k, v in averaged.items()}
    set_model_weights(global_weights)
    print("Global model updated with averaged weights")

@app.route('/get_weights', methods=['GET'])
def get_weights():

    quantized  = quantize_weights(global_weights)
    pickled    = pickle.dumps(quantized)
    compressed = gzip.compress(pickled, compresslevel=6)
    encoded    = base64.b64encode(compressed).decode('utf-8')

    return jsonify({
        'weights' : encoded,
        'round'   : current_round,
        'status'  : 'success'
    })

@app.route('/submit_weights', methods=['POST'])
def submit_weights():
    global client_weights, global_weights, current_round

    if not request.is_json:
        return jsonify({'status': 'error', 'message': 'Content-Type must be application/json'}), 415

    data = request.get_json(force=True, silent=True)
    if data is None:
        return jsonify({'status': 'error', 'message': 'Failed to parse JSON body'}), 400

    client_id       = data.get('client_id')
    encoded_weights = data.get('weights')
    client_round    = data.get('round', 0)
    original_size   = data.get('original_size_mb', 0)
    compressed_size = data.get('compressed_size_mb', 0)

    if client_id is None or encoded_weights is None:
        return jsonify({'status': 'error', 'message': 'Missing client_id or weights'}), 400

    try:
        compressed     = base64.b64decode(encoded_weights)
        pickled        = gzip.decompress(compressed)
        quantized_dict = pickle.loads(pickled)
        weights        = dequantize_weights(quantized_dict)
    except Exception as e:
        import traceback; traceback.print_exc()
        return jsonify({'status': 'error', 'message': str(e)}), 500

    client_weights[client_id] = weights

    ratio = original_size / compressed_size if compressed_size > 0 else 0
    print(f" Client {client_id} | Round {client_round} | "
          f"{original_size:.1f}MB → {compressed_size:.1f}MB ({ratio:.1f}x) | "
          f"{len(client_weights)}/{num_clients} clients")

    should_aggregate = (len(client_weights) == num_clients)

    response = jsonify({
        'status'            : 'success',
        'message'           : f'Weights received from client {client_id}',
        'clients_submitted' : len(client_weights),
        'total_clients'     : num_clients,
        'aggregating'       : should_aggregate
    })

    if should_aggregate:
        def do_aggregate():
            global global_weights, client_weights, current_round
            print(f"\n Aggregating round {current_round + 1}...")
            aggregate_weights()
            current_round += 1
            client_weights = {}
            print(f"✓ Round {current_round} complete!\n")

            if current_round >= max_rounds:
                print(" TRAINING COMPLETE!")
                save_final_model()

        threading.Thread(target=do_aggregate, daemon=True).start()

    return response

@app.route('/status', methods=['GET'])
def status():
    return jsonify({
        'status'            : 'running',
        'current_round'     : current_round,
        'max_rounds'        : max_rounds,
        'clients_submitted' : len(client_weights),
        'total_clients'     : num_clients
    })

def save_final_model():
    print("\n" + "="*60)
    print("SAVING FINAL MODEL...")
    print("="*60)

    try:
        # FIX: import inside function — keeps pickle namespace clean
        from transformers import DistilBertForQuestionAnswering

        os.makedirs('federated_qa_model', exist_ok=True)
        global_model.save_pretrained('federated_qa_model')
        print("Saved to 'federated_qa_model/'")

        torch.save(global_model.state_dict(), 'federated_qa_model.pt')
        print("Saved as 'federated_qa_model.pt'")

        total_params = sum(p.numel() for p in global_model.parameters())
        size_mb      = os.path.getsize('federated_qa_model.pt') / (1024 * 1024)
        print(f"\n Parameters : {total_params:,}")
        print(f" Size       : {size_mb:.2f} MB")


    except Exception as e:
        print(f" Error saving: {e}")
        import traceback; traceback.print_exc()

def run_flask():
    app.run(host='0.0.0.0', port=5000, threaded=True, use_reloader=False)

if __name__ == '__main__':
    subprocess.run(['pkill', '-f', 'bore'], capture_output=True)
    time.sleep(1)

    flask_thread = threading.Thread(target=run_flask, daemon=True)
    flask_thread.start()
    time.sleep(2)
    print("Flask started on port 5000\n")

    bore_process = subprocess.Popen(
        ['bore', 'local', '5000', '--to', 'bore.pub'],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True, bufsize=1
    )

    for line in bore_process.stdout:
        print(line.rstrip())
        if 'bore.pub:' in line:
            try:
                port = line.split('bore.pub:')[1].split()[0].strip().rstrip('/')
                url  = f"http://bore.pub:{port}"
                print("\n" + "="*60)
                print(f"SERVER URL: {url}")
                print("="*60)
                print(f"Server ready — waiting for {num_clients} clients")
                print(f"Training: {max_rounds} rounds\n")
            except Exception:
                pass

    bore_process.wait()

 Initializing DistilBERT for Question Answering...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

DistilBertForQuestionAnswering LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_transform.weight  | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
qa_outputs.bias         | MISSING    | 
qa_outputs.weight       | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


 Model ready — 66,364,418 parameters
Global weights extracted!
 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on all addresses (0.0.0.0)
 * Running on http://127.0.0.1:5000
 * Running on http://172.28.0.12:5000
INFO:werkzeug:[33mPress CTRL+C to quit[0m


Flask started on port 5000

[2m2026-02-13T14:45:27.576629Z[0m [32m INFO[0m [2mbore_cli::client[0m[2m:[0m connected to server [3mremote_port[0m[2m=[0m64534
[2m2026-02-13T14:45:27.577115Z[0m [32m INFO[0m [2mbore_cli::client[0m[2m:[0m listening at bore.pub:64534

SERVER URL: http://bore.pub:64534
Server ready — waiting for 2 clients
Training: 1 rounds

[2m2026-02-13T14:46:55.108356Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0ma8cd2eb3-4d86-4696-be64-29291eeb234b[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m new connection


INFO:werkzeug:127.0.0.1 - - [13/Feb/2026 14:47:17] "GET /get_weights HTTP/1.1" 200 -


[2m2026-02-13T14:47:21.872582Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0ma8cd2eb3-4d86-4696-be64-29291eeb234b[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m connection exited
[2m2026-02-13T14:47:24.741862Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0mc02e1c8f-75ea-4657-a06d-262a27a87893[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m new connection


INFO:werkzeug:127.0.0.1 - - [13/Feb/2026 14:47:40] "GET /get_weights HTTP/1.1" 200 -


[2m2026-02-13T14:47:44.765239Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0mc02e1c8f-75ea-4657-a06d-262a27a87893[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m connection exited
[2m2026-02-13T15:16:15.264620Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0m3999d346-947d-46fb-8318-87a2434a4279[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m new connection


INFO:werkzeug:127.0.0.1 - - [13/Feb/2026 15:16:27] "POST /submit_weights HTTP/1.1" 200 -


 Client 2 | Round 0 | 253.2MB → 124.0MB (2.0x) | 1/2 clients
[2m2026-02-13T15:16:28.001875Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0m3999d346-947d-46fb-8318-87a2434a4279[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m connection exited
[2m2026-02-13T15:17:42.686032Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0m3e132d93-e6d6-4156-a69d-9a208c7e9ee7[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m new connection


INFO:werkzeug:127.0.0.1 - - [13/Feb/2026 15:17:55] "POST /submit_weights HTTP/1.1" 200 -


 Client 1 | Round 0 | 253.2MB → 124.0MB (2.0x) | 2/2 clients

 Aggregating round 1...
[2m2026-02-13T15:17:55.736697Z[0m [32m INFO[0m [1mproxy[0m[1m{[0m[3mid[0m[2m=[0m3e132d93-e6d6-4156-a69d-9a208c7e9ee7[1m}[0m[2m:[0m [2mbore_cli::client[0m[2m:[0m connection exited
Global model updated with averaged weights
✓ Round 1 complete!

 TRAINING COMPLETE!

SAVING FINAL MODEL...


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

Saved to 'federated_qa_model/'
Saved as 'federated_qa_model.pt'

 Parameters : 66,364,418
 Size       : 253.20 MB


KeyboardInterrupt: 

In [5]:
import torch
import os
import numpy as np
from google.colab import files
from transformers import DistilBertForQuestionAnswering

print("Loading trained federated model...")
model = DistilBertForQuestionAnswering.from_pretrained('federated_qa_model')
model.eval()

original_size   = os.path.getsize('federated_qa_model.pt') / (1024 * 1024)
original_params = sum(p.numel() for p in model.parameters())
print(f"Model loaded!")
print(f"Parameters : {original_params:,}")
print(f"Size: {original_size:.2f} MB\n")

results = {}

print("─"*55)
print("Dynamic INT8 Quantization")
print("─"*55)
try:
    dynamic_model = torch.quantization.quantize_dynamic(
        model,
        {torch.nn.Linear},
        dtype=torch.qint8
    )
    path = 'model_dynamicc_int8.pt'
    torch.save(dynamic_model.state_dict(), path)
    size = os.path.getsize(path) / (1024 * 1024)
    results['dynamic'] = {'path': path, 'size': size}
    print(f"Saved : {path}")
    print(f"Size  : {size:.2f} MB  ({original_size/size:.1f}x smaller)\n")
except Exception as e:
    print(f"Failed: {e}\n")
print("─"*55)
print("ONNX Export (float32)")
print("─"*55)
try:
    dummy_ids  = torch.zeros(1, 384, dtype=torch.long)
    dummy_mask = torch.zeros(1, 384, dtype=torch.long)

    torch.onnx.export(
        model,
        (dummy_ids, dummy_mask),
        'model_float32.onnx',
        input_names   = ['input_ids', 'attention_mask'],
        output_names  = ['start_logits', 'end_logits'],
        dynamic_axes  = {
            'input_ids'     : {0: 'batch', 1: 'seq'},
            'attention_mask': {0: 'batch', 1: 'seq'},
            'start_logits'  : {0: 'batch'},
            'end_logits'    : {0: 'batch'}
        },
        opset_version = 13
    )
    size = os.path.getsize('model_float32.onnx') / (1024 * 1024)
    results['onnx'] = {'path': 'model_float32.onnx', 'size': size}
    print(f"Saved : model_float32.onnx")
    print(f"Size  : {size:.2f} MB  ({original_size/size:.1f}x smaller)\n")
except Exception as e:
    print(f"Failed: {e}\n")


print("─"*55)
print("ONNX + INT8 Quantization")
print("─"*55)
try:
    from onnxruntime.quantization import quantize_dynamic, QuantType

    quantize_dynamic(
        'model_float32.onnx',
        'model_int8.onnx',
        weight_type = QuantType.QInt8
    )
    size = os.path.getsize('model_int8.onnx') / (1024 * 1024)
    results['onnx_int8'] = {'path': 'model_int8.onnx', 'size': size}
    print(f"Saved : model_int8.onnx")
    print(f"Size  : {size:.2f} MB  ({original_size/size:.1f}x smaller)\n")
except ImportError:
    print("Run: !pip install onnxruntime -q  then re-run this cell\n")
except Exception as e:
    print(f"Failed: {e}\n")


label_map = {
    'dynamic'  : 'Dynamic INT8 (PyTorch)',
    'onnx'     : 'ONNX float32',
    'onnx_int8': 'ONNX INT8',
}

print("="*55)
print("COMPRESSION SUMMARY")
print("="*55)
print(f"  {'Model':<28} {'Size':>9}  {'Reduction':>10}")
print(f"  {'-'*50}")
print(f"  {'Original (float32 .pt)':<28} {original_size:>8.2f}MB  {'1.0x':>10}")
for k, label in label_map.items():
    if k in results:
        r = results[k]
        print(f"  {label:<28} {r['size']:>8.2f}MB  {original_size/r['size']:>9.1f}x")
print("="*55)

print("\n Downloading all files...")
download_list = [
    ('federated_qa_model.pt', 'Original model'),
    *[(r['path'], label_map[k]) for k, r in results.items()]
]

for path, label in download_list:
    if os.path.exists(path):
        size = os.path.getsize(path) / (1024 * 1024)
        print(f" {label:<28} ({size:.2f} MB)")
        files.download(path)
    else:
        print(f"{label:<28} not found")

print("\n Done!")

Loading trained federated model...


Loading weights:   0%|          | 0/102 [00:00<?, ?it/s]

Model loaded!
Parameters : 66,364,418
Size: 253.20 MB

───────────────────────────────────────────────────────
Dynamic INT8 Quantization
───────────────────────────────────────────────────────


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  dynamic_model = torch.quantization.quantize_dynamic(


Saved : model_dynamicc_int8.pt
Size  : 131.72 MB  (1.9x smaller)

───────────────────────────────────────────────────────
ONNX Export (float32)
───────────────────────────────────────────────────────
Failed: No module named 'onnxscript'

───────────────────────────────────────────────────────
ONNX + INT8 Quantization
───────────────────────────────────────────────────────
Run: !pip install onnxruntime -q  then re-run this cell

COMPRESSION SUMMARY
  Model                             Size   Reduction
  --------------------------------------------------
  Original (float32 .pt)         253.20MB        1.0x
  Dynamic INT8 (PyTorch)         131.72MB        1.9x

 Downloading all files...
 Original model               (253.20 MB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

 Dynamic INT8 (PyTorch)       (131.72 MB)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>


 Done!


In [6]:
import torch
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

context = """
Maria owns a small bakery in downtown Portland. She opens at 6 AM every morning
and bakes fresh bread, croissants, and cookies. Her specialty is sourdough bread,
which takes 24 hours to prepare. The bakery closes at 5 PM on weekdays and
3 PM on Sundays.
"""

question = "What is Maria's specialty at the bakery?"

inputs = tokenizer(
    question,
    context,
    return_tensors='pt',
    max_length=384,
    truncation='only_second',
    padding='max_length'
)

def get_answer(model, inputs):
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        start_idx = torch.argmax(outputs.start_logits)
        end_idx = torch.argmax(outputs.end_logits)
        tokens = inputs['input_ids'][0][start_idx:end_idx + 1]
        answer = tokenizer.decode(tokens, skip_special_tokens=True)
    return answer

print("=" * 55)
print("Model 1: Original float32 (253.20 MB)")
print("=" * 55)
model_original = DistilBertForQuestionAnswering.from_pretrained('federated_qa_model')
answer = get_answer(model_original, inputs)
print(f"Question: {question}")
print(f"Answer: {answer}\n")

print("=" * 55)
print("Model 2: Dynamic INT8 (131.72 MB)")
print("=" * 55)

model_int8 = torch.quantization.quantize_dynamic(
    DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased'),
    {torch.nn.Linear},
    dtype=torch.qint8
)
model_int8.load_state_dict(
    torch.load('model_dynamic_int8.pt', map_location='cpu')
)
answer_int8 = get_answer(model_int8, inputs)
print(f"Question: {question}")
print(f"Answer: {answer_int8}\n")

print("=" * 55)
print("COMPARISON")
print("=" * 55)
print(f"{'Model':<25} {'Size':>10}  {'Answer'}")
print(f"{'-' * 53}")
print(f"{'Original float32':<25} {'253.20 MB':>10}  {answer}")
print(f"{'Dynamic INT8':<25} {'131.72 MB':>10}  {answer_int8}")
print(f"\nAnswers match: {'YES' if answer == answer_int8 else 'NO (slight difference)'}")
print("=" * 55)

Model 1: Original float32 (253.20 MB)


Loading weights:   0%|          | 0/102 [00:00<?, ?it/s]

Question: What is Maria's specialty at the bakery?
Answer: sourdough bread

Model 2: Dynamic INT8 (131.72 MB)


Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

DistilBertForQuestionAnswering LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_transform.weight  | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
qa_outputs.bias         | MISSING    | 
qa_outputs.weight       | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please

Question: What is Maria's specialty at the bakery?
Answer: sourdough bread

COMPARISON
Model                           Size  Answer
-----------------------------------------------------
Original float32           253.20 MB  sourdough bread
Dynamic INT8               131.72 MB  sourdough bread

Answers match: YES


In [8]:
import torch
import numpy as np
from transformers import DistilBertForQuestionAnswering, DistilBertTokenizerFast
from datasets import load_dataset
import collections
import string
import re

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

print("Loading SQuAD validation set...")
val_dataset = load_dataset('squad', split='validation')

val_dataset = val_dataset.select(range(500))
print(f" {len(val_dataset)} examples loaded\n")


def normalize_answer(s):
    """Lowercase, remove punctuation, articles, extra whitespace"""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)
    def remove_punctuation(text):
        return ''.join(ch for ch in text if ch not in string.punctuation)
    def white_space_fix(text):
        return ' '.join(text.split())
    return white_space_fix(remove_articles(remove_punctuation(s.lower())))

def exact_match(prediction, ground_truths):
    pred_norm = normalize_answer(prediction)
    return any(pred_norm == normalize_answer(gt) for gt in ground_truths)

def f1_score(prediction, ground_truths):
    def token_f1(pred, gt):
        pred_tokens = normalize_answer(pred).split()
        gt_tokens   = normalize_answer(gt).split()
        common      = collections.Counter(pred_tokens) & collections.Counter(gt_tokens)
        num_same    = sum(common.values())
        if num_same == 0:
            return 0.0
        precision = num_same / len(pred_tokens)
        recall    = num_same / len(gt_tokens)
        return (2 * precision * recall) / (precision + recall)
    return max(token_f1(prediction, gt) for gt in ground_truths)

def predict_answer(model, question, context):

    inputs = tokenizer(
        question,
        context,
        return_tensors  = 'pt',
        max_length      = 384,
        truncation      = 'only_second',
        padding         = 'max_length',
        return_offsets_mapping = False
    )
    model.eval()
    with torch.no_grad():
        outputs   = model(**inputs)
        start_idx = torch.argmax(outputs.start_logits)
        end_idx   = torch.argmax(outputs.end_logits)

        # Guard: end must be >= start
        if end_idx < start_idx:
            end_idx = start_idx

        tokens = inputs['input_ids'][0][start_idx : end_idx + 1]
        answer = tokenizer.decode(tokens, skip_special_tokens=True)
    return answer.strip()

def evaluate_model(model, dataset, model_name):

    print(f"\n Evaluating {model_name} on {len(dataset)} examples...")

    em_scores = []
    f1_scores = []

    for i, example in enumerate(dataset):
        question     = example['question']
        context      = example['context']
        ground_truth = example['answers']['text']   # list of valid answers

        prediction   = predict_answer(model, question, context)

        em_scores.append(1.0 if exact_match(prediction, ground_truth) else 0.0)
        f1_scores.append(f1_score(prediction, ground_truth))

        if (i + 1) % 100 == 0:
            print(f"  [{i+1}/{len(dataset)}] "
                  f"EM: {np.mean(em_scores)*100:.2f}%  "
                  f"F1: {np.mean(f1_scores)*100:.2f}%")

    final_em = np.mean(em_scores) * 100
    final_f1 = np.mean(f1_scores) * 100
    return final_em, final_f1

# MODEL 1 — Original float32

print("="*55)
print("Model 1: Original float32 (253.20 MB)")
print("="*55)
model_original = DistilBertForQuestionAnswering.from_pretrained('federated_qa_model')
em1, f1_1 = evaluate_model(model_original, val_dataset, "Original float32")
print(f" Exact Match : {em1:.2f}%")
print(f" F1 Score    : {f1_1:.2f}%")


print("\n" + "="*55)
print("Model 2: Dynamic INT8 (131.72 MB)")
print("="*55)
model_int8 = torch.quantization.quantize_dynamic(
    DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased'),
    {torch.nn.Linear},
    dtype=torch.qint8
)
model_int8.load_state_dict(
    torch.load('model_dynamic_int8.pt', map_location='cpu')
)
em2, f1_2 = evaluate_model(model_int8, val_dataset, "Dynamic INT8")
print(f"Exact Match : {em2:.2f}%")
print(f"F1 Score    : {f1_2:.2f}%")

# ========================================
# 📊 FINAL COMPARISON
# ========================================
print("\n" + "="*55)
print("FINAL COMPARISON")
print("="*55)
print(f"  {'Model':<25} {'Size':>9}  {'EM':>8}  {'F1':>8}")
print(f"  {'-'*53}")
print(f"  {'Original float32':<25} {'253.20MB':>9}  {em1:>7.2f}%  {f1_1:>7.2f}%")
print(f"  {'Dynamic INT8':<25} {'131.72MB':>9}  {em2:>7.2f}%  {f1_2:>7.2f}%")
print(f"  {'-'*53}")
print(f"  {'EM  drop':<25} {'':>9}  {em1-em2:>7.2f}%")
print(f"  {'F1  drop':<25} {'':>9}  {f1_1-f1_2:>7.2f}%")
print(f"  {'Size reduction':<25} {'1.9x':>9}")
print("="*55)

if abs(f1_1 - f1_2) < 2.0:
    print("\n Quantization maintained accuracy — F1 drop < 2%")
else:
    print(f"\n F1 dropped {f1_1-f1_2:.2f}% after quantization")


Loading SQuAD validation set...
 500 examples loaded

Model 1: Original float32 (253.20 MB)


Loading weights:   0%|          | 0/102 [00:00<?, ?it/s]


 Evaluating Original float32 on 500 examples...
  [100/500] EM: 70.00%  F1: 78.88%
  [200/500] EM: 72.00%  F1: 81.01%
  [300/500] EM: 73.33%  F1: 79.68%
  [400/500] EM: 73.50%  F1: 79.25%
  [500/500] EM: 73.20%  F1: 79.12%
 Exact Match : 73.20%
 F1 Score    : 79.12%

Model 2: Dynamic INT8 (131.72 MB)


Loading weights:   0%|          | 0/100 [00:00<?, ?it/s]

DistilBertForQuestionAnswering LOAD REPORT from: distilbert-base-uncased
Key                     | Status     | 
------------------------+------------+-
vocab_transform.weight  | UNEXPECTED | 
vocab_transform.bias    | UNEXPECTED | 
vocab_projector.bias    | UNEXPECTED | 
vocab_layer_norm.weight | UNEXPECTED | 
vocab_layer_norm.bias   | UNEXPECTED | 
qa_outputs.bias         | MISSING    | 
qa_outputs.weight       | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.
For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please


 Evaluating Dynamic INT8 on 500 examples...
  [100/500] EM: 64.00%  F1: 75.51%
  [200/500] EM: 66.00%  F1: 76.29%
  [300/500] EM: 67.33%  F1: 74.86%
  [400/500] EM: 69.00%  F1: 75.62%
  [500/500] EM: 69.00%  F1: 75.75%
Exact Match : 69.00%
F1 Score    : 75.75%

FINAL COMPARISON
  Model                          Size        EM        F1
  -----------------------------------------------------
  Original float32           253.20MB    73.20%    79.12%
  Dynamic INT8               131.72MB    69.00%    75.75%
  -----------------------------------------------------
  EM  drop                                4.20%
  F1  drop                                3.37%
  Size reduction                 1.9x

 F1 dropped 3.37% after quantization


In [9]:
!pip install nbformat



In [14]:
!pip install nbformat -q

import nbformat
from google.colab import files
import requests

# Get the notebook from Colab's internal system
# First, let's check if it's in Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# FIND YOUR NOTEBOOK - Common locations:
import os
import glob

print("Searching for your notebook...\n")

# Search in Drive
drive_notebooks = glob.glob('/content/drive/MyDrive/**/*.ipynb', recursive=True)
colab_notebooks = glob.glob('/content/drive/MyDrive/Colab Notebooks/*.ipynb')

all_notebooks = drive_notebooks + colab_notebooks
all_notebooks = list(set(all_notebooks))  # Remove duplicates

# Show all found notebooks
if all_notebooks:
    print("Found these notebooks:")
    for i, nb in enumerate(all_notebooks, 1):
        print(f"{i}. {nb}")

    # Find the one with our name
    target = [nb for nb in all_notebooks if 'federated-lr' in nb]

    if target:
        notebook_path = target[0]
        print(f"\n✓ Found your notebook: {notebook_path}")

        # Load and fix it
        with open(notebook_path, 'r') as f:
            nb = nbformat.read(f, as_version=4)

        # Remove widgets
        if 'widgets' in nb.metadata:
            del nb.metadata['widgets']
            print("✓ Widgets metadata DELETED!")

            # Save it back
            with open(notebook_path, 'w') as f:
                nbformat.write(nb, f)

            print("\n" + "="*60)
            print("✓✓✓ YOUR NOTEBOOK IN GOOGLE DRIVE IS NOW FIXED!")
            print("="*60)
            print("\nJust:")
            print("1. File → Download → Download .ipynb")
            print("2. Push to GitHub")
        else:
            print("→ No widgets found")
    else:
        print("\n❌ 'federated-lr.ipynb' not found in Drive")
        print("Please manually download it, then run the manual fix")
else:
    print("No notebooks found in Google Drive")

Mounted at /content/drive
Searching for your notebook...

Found these notebooks:
1. /content/drive/MyDrive/Colab Notebooks/Untitled2 (1).ipynb
2. /content/drive/MyDrive/Colab Notebooks/Transformer.ipynb
3. /content/drive/MyDrive/Colab Notebooks/VlM_Experiment 2.ipynb
4. /content/drive/MyDrive/Colab Notebooks/Waste2Wealth.ipynb
5. /content/drive/MyDrive/Colab Notebooks/PR.ipynb
6. /content/drive/MyDrive/Colab Notebooks/Translation.ipynb
7. /content/drive/MyDrive/Colab Notebooks/Sans_english_transformers (1).ipynb
8. /content/drive/MyDrive/Colab Notebooks/MultiModal-rag.ipynb
9. /content/drive/MyDrive/Colab Notebooks/Copy of itm_brain_adam.ipynb
10. /content/drive/MyDrive/Colab Notebooks/Cloud_genie.ipynb
11. /content/drive/MyDrive/Colab Notebooks/brain_rs.ipynb
12. /content/drive/MyDrive/Colab Notebooks/eda.ipynb
13. /content/drive/MyDrive/Colab Notebooks/Brain (1).ipynb
14. /content/drive/MyDrive/Colab Notebooks/imagebind_ipynbー.ipynb
15. /content/drive/MyDrive/Colab Notebooks/Lama3Lla