In [5]:
# ==============================================================================
# Step 1: Install and Import Necessary Libraries
# ==============================================================================
!pip install datasets sentence-transformers torch

import time
import numpy as np
import pandas as pd
import torch
import os # Added for environment variable
from torch.utils.data import DataLoader
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, losses, models
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from copy import deepcopy

print(f"Setup complete. Using PyTorch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Setup complete. Using PyTorch version: 2.6.0+cu124
Using device: cuda


In [6]:
# ==============================================================================
# Step 2: Data Loading and Preparation
# ==============================================================================

def load_and_prepare_data(num_samples=10000):
    """Loads datasets and prepares them for the Siamese network."""

    # --- Load Datasets ---
    print(f"Loading English (PAWS) dataset ({num_samples} samples)...")
    en_dataset = load_dataset("paws", "labeled_final", split=f'train[:{num_samples}]')
    en_df = en_dataset.to_pandas()
    en_train_df = en_df.sample(frac=0.8, random_state=42)
    en_test_df = en_df.drop(en_train_df.index)
    print(f"English data loaded. Train: {len(en_train_df)}, Test: {len(en_test_df)}")

    print(f"Loading German (PAWS-X) dataset ({num_samples} samples)...")
    de_dataset = load_dataset("paws-x", "de", split=f'train[:{num_samples}]')
    de_df = de_dataset.to_pandas()
    de_train_df = de_df.sample(frac=0.8, random_state=42)
    de_test_df = de_df.drop(de_train_df.index)
    print(f"German data loaded. Train: {len(de_train_df)}, Test: {len(de_test_df)}")

    # --- Convert to InputExample format ---
    def df_to_examples(df, is_test=False):
        examples = []
        for index, row in df.iterrows():
            label = float(row['label'])
            if is_test:
                 label = label / 1.0
            examples.append(InputExample(texts=[row['sentence1'], row['sentence2']], label=label))
        return examples

    en_train_examples = df_to_examples(en_train_df)
    de_train_examples = df_to_examples(de_train_df)

    en_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(df_to_examples(en_test_df, is_test=True), name='en-test')
    de_evaluator = EmbeddingSimilarityEvaluator.from_input_examples(df_to_examples(de_test_df, is_test=True), name='de-test')

    return en_train_examples, de_train_examples, en_evaluator, de_evaluator

# Execute the function
en_train_examples, de_train_examples, en_evaluator, de_evaluator = load_and_prepare_data()

Loading English (PAWS) dataset (10000 samples)...
English data loaded. Train: 8000, Test: 2000
Loading German (PAWS-X) dataset (10000 samples)...
German data loaded. Train: 8000, Test: 2000


In [10]:
# ==============================================================================
# Step 3: Main Execution Block (Deep Learning Model)
# ==============================================================================
# Disable wandb using an environment variable before training
os.environ["WANDB_DISABLED"] = "true"

if __name__ == '__main__':
    # --- Configuration ---
    NUM_ROUNDS = 10
    LOCAL_EPOCHS = 1
    BATCH_SIZE = 16

    # --- Global Model Initialization ---
    print("Initializing the global model (SentenceTransformer)...")
    word_embedding_model = models.Transformer('distilbert-base-multilingual-cased')
    pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
    global_model = SentenceTransformer(modules=[word_embedding_model, pooling_model], device=device)

    loss_func = losses.CosineSimilarityLoss(model=global_model)
    print("Global model initialized.")

    # --- Federated Learning Loop ---
    client_train_examples = {'English': en_train_examples, 'German': de_train_examples}
    client_evaluators = {'English': en_evaluator, 'German': de_evaluator}

    for round_num in range(1, NUM_ROUNDS + 1):
        print(f"\n{'='*20} ROUND {round_num}/{NUM_ROUNDS} {'='*20}")

        local_weights = []
        client_samples = []

        for client_id, train_examples in client_train_examples.items():
            print(f"[Client {client_id}] Starting training...")

            local_model = deepcopy(global_model).to(device)
            train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=BATCH_SIZE)

            local_model.fit(train_objectives=[(train_dataloader, loss_func)],
                            epochs=LOCAL_EPOCHS,
                            warmup_steps=100,
                            show_progress_bar=True)

            local_weights.append(local_model.state_dict())
            client_samples.append(len(train_examples))
            print(f"[Client {client_id}] Training finished.")

        # --- Federated Averaging (Aggregation for PyTorch) ---
        print("\n[Server] Aggregating model updates from clients...")
        total_samples = sum(client_samples)
        global_weights = global_model.state_dict()

        for key in global_weights.keys():
            global_weights[key] = torch.zeros_like(global_weights[key])

        for i, client_weight in enumerate(local_weights):
            weight = client_samples[i] / total_samples
            for key in global_weights.keys():
                global_weights[key] += client_weight[key] * weight

        global_model.load_state_dict(global_weights)
        print("[Server] Aggregation complete.")

        # --- Round-wise Evaluation ---
        print(f"\n--- Evaluating Global Model at Round {round_num} ---")
        for client_id, evaluator in client_evaluators.items():
            score_dict = evaluator(global_model, output_path=".")

            # --- DEBUGGING STEP ---
            # Let's print the dictionary to see what keys are available
            print(f"DEBUG: Full score dictionary for {client_id}: {score_dict}")

            # FIX: The correct key is different in some library versions.
            # Let's try the most common variations.
            key_to_try1 = f"{evaluator.name}_spearman_cosine"
            key_to_try2 = f"{evaluator.name}_cosine_spearman" # Another common variation

            if key_to_try1 in score_dict:
                score = score_dict[key_to_try1]
            elif key_to_try2 in score_dict:
                score = score_dict[key_to_try2]
            else:
                print(f"ERROR: Could not find the correct score key in the dictionary for {client_id}.")
                continue # Skip to the next client if key is not found

            print(f"  Results for {client_id}: Spearman Correlation = {score:.4f}")

    print(f"\n{'='*20} FEDERATED TRAINING COMPLETE {'='*20}")

Initializing the global model (SentenceTransformer)...
Global model initialized.

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2761


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.2873


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 1 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.33951562207226677, 'en-test_spearman_cosine': 0.35507744815920045}
  Results for English: Spearman Correlation = 0.3551
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.3589930564210201, 'de-test_spearman_cosine': 0.3914231518531139}
  Results for German: Spearman Correlation = 0.3914

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1832


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1951


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 2 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.3918300301849513, 'en-test_spearman_cosine': 0.40629421008396926}
  Results for English: Spearman Correlation = 0.4063
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.39827760034012344, 'de-test_spearman_cosine': 0.4296892429432385}
  Results for German: Spearman Correlation = 0.4297

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1322


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.141


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 3 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.41158857658468606, 'en-test_spearman_cosine': 0.42171807673414596}
  Results for English: Spearman Correlation = 0.4217
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4127906837697878, 'de-test_spearman_cosine': 0.44301244401985224}
  Results for German: Spearman Correlation = 0.4430

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0944


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.1009


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 4 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.4168041030728784, 'en-test_spearman_cosine': 0.4234762600387492}
  Results for English: Spearman Correlation = 0.4235
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4188942380613436, 'de-test_spearman_cosine': 0.4521898833411043}
  Results for German: Spearman Correlation = 0.4522

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.073


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0772


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 5 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.4114852154412978, 'en-test_spearman_cosine': 0.4167398884601001}
  Results for English: Spearman Correlation = 0.4167
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4104221103912199, 'de-test_spearman_cosine': 0.4478819292055125}
  Results for German: Spearman Correlation = 0.4479

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0601


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0644


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 6 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.41270251574301864, 'en-test_spearman_cosine': 0.4167598775374574}
  Results for English: Spearman Correlation = 0.4168
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4050026056381179, 'de-test_spearman_cosine': 0.4422879523698723}
  Results for German: Spearman Correlation = 0.4423

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0506


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.054


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 7 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.40695291928664634, 'en-test_spearman_cosine': 0.409370813869499}
  Results for English: Spearman Correlation = 0.4094
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.40362412908074113, 'de-test_spearman_cosine': 0.44380873730459164}
  Results for German: Spearman Correlation = 0.4438

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0443


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss


Step,Training Loss
500,0.0482


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 8 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.4060442788991426, 'en-test_spearman_cosine': 0.4123222676705325}
  Results for English: Spearman Correlation = 0.4123
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4057241872374092, 'de-test_spearman_cosine': 0.44015499495107935}
  Results for German: Spearman Correlation = 0.4402

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0404


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0435


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 9 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.40230391418403777, 'en-test_spearman_cosine': 0.4067904646889275}
  Results for English: Spearman Correlation = 0.4068
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.4018054995815382, 'de-test_spearman_cosine': 0.43766042056344007}
  Results for German: Spearman Correlation = 0.4377

[Client English] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0365


[Client English] Training finished.
[Client German] Starting training...


Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).
Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

Step,Training Loss
500,0.0372


[Client German] Training finished.

[Server] Aggregating model updates from clients...
[Server] Aggregation complete.

--- Evaluating Global Model at Round 10 ---
DEBUG: Full score dictionary for English: {'en-test_pearson_cosine': 0.4083944096812956, 'en-test_spearman_cosine': 0.4140900112015401}
  Results for English: Spearman Correlation = 0.4141
DEBUG: Full score dictionary for German: {'de-test_pearson_cosine': 0.3916819527084348, 'de-test_spearman_cosine': 0.4316759597124304}
  Results for German: Spearman Correlation = 0.4317

