# Federated LLM Text Summarization (T5-small) — Notebook

**Master’s Thesis (Mar–May 2024) — University of Skövde**  
This notebook consolidates the project code (server, clients, and evaluation) for **federated fine-tuning** of **T5-small** using **Flower (FedAvg)**.  
It is intended for local experimentation and reproducibility.

> **Note:** The original repo organizes code under `client/`, `server/`, and `utils/`.  
> For this notebook to run, make sure a `utils/` folder with `data_utils.py` (providing `preprocess_data`, `load_and_split_data`, and `load_local_data`) and `model_utils.py` (providing `get_model`, `save_model`, and `load_model`) is available in the **same directory** as this notebook (or adjust imports accordingly).


## 1. Environment Setup

Install dependencies (uncomment and run once per environment):


In [None]:
# !pip install torch transformers datasets sentencepiece rouge-score evaluate accelerate peft flwr pyyaml pydantic tqdm numpy pandas scikit-learn matplotlib flask
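The Flower API changed in 1.0 (for example, `start_server` takes a `fl.server.ServerConfig` instead of a dict, and `NumPyClient.get_parameters` gained a `config` argument), so it helps to record which versions are installed before running the cells below. A minimal standard-library sketch; `installed_versions` is a hypothetical helper name, not part of the project.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for pkg, ver in installed_versions(["flwr", "torch", "transformers", "rouge-score"]).items():
        print(f"{pkg:>14}: {ver or 'not installed'}")
```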

## 2. Dataset

- Recommended dataset (example): **medical_meadow_cord19** from Hugging Face.  
- Place your data per-client as needed, or adapt `utils/data_utils.py` to load it.

Example link (from your README):  
https://huggingface.co/datasets/medalpaca/medical_meadow_cord19
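Each client needs its own disjoint portion of the data. `load_and_split_data(client_id)` lives in `utils/data_utils.py` and is not shown here; if you need to write it, one simple option is to shard a single JSON list of records. The sketch below assumes the file is a JSON array of objects and three clients (matching the README's run commands); `shard_for_client`, `load_client_records`, and the fixed seed are hypothetical choices, not the project's actual implementation.

```python
import json
import random


def shard_for_client(records, client_id, num_clients, seed=42):
    """Return the disjoint slice of `records` assigned to `client_id`.

    Records are shuffled with a fixed seed so every client process computes the
    same permutation, then dealt out round-robin, so shards never overlap.
    """
    if not 0 <= client_id < num_clients:
        raise ValueError(f"client_id must be in [0, {num_clients})")
    order = list(range(len(records)))
    random.Random(seed).shuffle(order)
    return [records[i] for i in order[client_id::num_clients]]


def load_client_records(path, client_id, num_clients=3):
    # Assumes `path` holds a JSON array of objects (hypothetical layout).
    with open(path, "r", encoding="utf-8") as f:
        records = json.load(f)
    return shard_for_client(records, client_id, num_clients)
```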

> **Privacy:** Do not place sensitive/real patient data in this notebook or repo.


## 3. Server

Runs the Flower server with a custom FedAvg strategy that saves the aggregated model after each round.
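FedAvg aggregates each parameter tensor by averaging the clients' tensors, weighted by the number of training examples each client reports from `fit` (here `len(self.dataloader.dataset)`). The NumPy sketch below shows that rule in isolation; it is an illustration, not Flower's internal code, and `weighted_average` is a hypothetical name.

```python
import numpy as np


def weighted_average(client_updates):
    """FedAvg rule: per-layer mean of client arrays, weighted by example counts.

    client_updates: list of (list_of_ndarrays, num_examples) pairs, one per client,
    with the arrays in the same (state_dict) order for every client.
    """
    total = sum(n for _, n in client_updates)
    if total == 0:
        raise ValueError("clients reported zero training examples")
    num_layers = len(client_updates[0][0])
    return [
        sum(params[i] * n for params, n in client_updates) / total
        for i in range(num_layers)
    ]
```

A client with twice as many examples pulls the average twice as hard, so uneven client datasets shift the global model toward the larger clients.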


In [None]:
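import flwr as fl
from flwr.common import parameters_to_ndarrays
from flwr.server.strategy import FedAvg
from transformers import T5ForConditionalGeneration
from utils.model_utils import get_model, save_model
import logging
import torch

# Configure logging
logging.basicConfig(level=logging.DEBUG)


class SaveModelStrategy(FedAvg):
    """FedAvg that also saves the aggregated model after every round (Flower 1.x API)."""

    def aggregate_fit(self, server_round, results, failures):
        # FedAvg returns a (Parameters | None, metrics) tuple, not a list of weights.
        aggregated_parameters, aggregated_metrics = super().aggregate_fit(server_round, results, failures)
        if aggregated_parameters is not None:
            ndarrays = parameters_to_ndarrays(aggregated_parameters)
            model = T5ForConditionalGeneration.from_pretrained("t5-small")
            # Clients send arrays in state_dict() key order, so the keys can be zipped back on.
            state_dict = {k: torch.tensor(v) for k, v in zip(model.state_dict().keys(), ndarrays)}
            model.load_state_dict(state_dict)
            # This file name is what the evaluation cell loads.
            model_path = f"aggregated_model_round_{server_round}.pth"
            save_model(model, model_path)
            logging.info(f"Aggregated model for round {server_round} saved to {model_path}")
        return aggregated_parameters, aggregated_metrics


def round_config(server_round):
    # Sent to both fit() and evaluate() on the clients, which read config["round"].
    return {"round": server_round}


def start_server(num_rounds=3):
    logging.info("Loading pre-trained model...")
    model = get_model()
    model_path = "model_initial.pth"
    save_model(model, model_path)
    logging.info(f"Initial model saved to {model_path}")

    strategy = SaveModelStrategy(
        min_available_clients=2,
        min_fit_clients=2,
        min_evaluate_clients=2,  # called min_eval_clients before Flower 1.0
        on_fit_config_fn=round_config,
        on_evaluate_config_fn=round_config,
    )

    logging.info("Starting Flower server...")
    try:
        fl.server.start_server(
            server_address="0.0.0.0:8080",
            config=fl.server.ServerConfig(num_rounds=num_rounds),
            strategy=strategy,
        )
    except Exception as e:
        logging.error(f"Error starting server: {e}", exc_info=True)


if __name__ == "__main__":
    start_server()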


## 4. Client

Implements a Flower NumPyClient for local training/evaluation with ROUGE metrics.
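`get_parameters` and `set_parameters` only work because both sides agree on the key order of `state_dict()`: Flower ships a plain list of NumPy arrays with no names attached. The round trip below uses an `OrderedDict` of NumPy arrays as a stand-in for a PyTorch state dict (an assumption so the sketch runs without PyTorch) and adds length and shape checks that the notebook's version omits; `to_ndarrays` and `from_ndarrays` are hypothetical helper names.

```python
from collections import OrderedDict

import numpy as np


def to_ndarrays(state_dict):
    # Values are emitted in the dict's key order; the receiver relies on that order.
    return [np.asarray(v) for v in state_dict.values()]


def from_ndarrays(reference_state_dict, arrays):
    """Re-attach names to a list of arrays using the reference dict's key order."""
    keys = list(reference_state_dict.keys())
    if len(keys) != len(arrays):
        raise ValueError(f"expected {len(keys)} arrays, got {len(arrays)}")
    rebuilt = OrderedDict()
    for key, arr in zip(keys, arrays):
        ref = np.asarray(reference_state_dict[key])
        if arr.shape != ref.shape:
            raise ValueError(f"shape mismatch for {key}: {arr.shape} vs {ref.shape}")
        # Keep the reference dtype instead of forcing float32 on every entry.
        rebuilt[key] = arr.astype(ref.dtype, copy=False)
    return rebuilt
```

With the real model, the reference would be `{k: v.cpu().numpy() for k, v in model.state_dict().items()}`, and each rebuilt array would be wrapped with `torch.from_numpy` before `model.load_state_dict`.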


In [None]:
import flwr as fl
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer
from utils.model_utils import get_model
from utils.data_utils import preprocess_data, load_and_split_data
from rouge_score import rouge_scorer
import os
import logging
from pathlib import Path
import json

# Configure logging
logging.basicConfig(level=logging.DEBUG)


class SummarizationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


class SummarizationClient(fl.client.NumPyClient):
    def __init__(self, model, dataloader, tokenizer, client_id):
        self.model = model
        self.dataloader = dataloader
        self.tokenizer = tokenizer
        self.client_id = client_id
        self.optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
        self.scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
        self.scores_file = Path(f"client_{self.client_id}_scores.json")

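    def get_parameters(self, config=None):
        # Flower >= 1.0 passes a config dict; the default keeps internal calls like fit() working.
        params = [val.cpu().numpy() for _, val in self.model.state_dict().items()]
        logging.info(f"[CLIENT {self.client_id}] get_parameters called")
        return params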

    def set_parameters(self, parameters):
        params_dict = zip(self.model.state_dict().keys(), parameters)
        state_dict = {k: torch.tensor(v, dtype=torch.float32) for k, v in params_dict}
        self.model.load_state_dict(state_dict)
        logging.info(f"[CLIENT {self.client_id}] set_parameters called")

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        self.model.train()
        total_loss = 0.0
        num_batches = 0
        logging.info(f"[CLIENT {self.client_id}] Training started (round {config.get('round')})")
        try:
            for batch in self.dataloader:
                self.optimizer.zero_grad()
                inputs = preprocess_data(batch, self.tokenizer)
                # Log shapes only; dumping whole batches at DEBUG level floods the log.
                logging.debug(f"[CLIENT {self.client_id}] Batch {num_batches}: "
                              f"input_ids shape {tuple(inputs['input_ids'].shape)}")
                outputs = self.model(**inputs)
                loss = outputs.loss
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
                num_batches += 1
            logging.info(f"[CLIENT {self.client_id}] Training completed, "
                         f"average loss: {total_loss / max(num_batches, 1):.4f}")
        except Exception as e:
            logging.error(f"[CLIENT {self.client_id}] Error during training: {e}", exc_info=True)

        model_path = f"model_client_{self.client_id}.pth"
        torch.save(self.model.state_dict(), model_path)
        logging.info(f"[CLIENT {self.client_id}] Model saved as {model_path}")

        return self.get_parameters(), len(self.dataloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        self.model.eval()
        total_loss = 0.0
        rouge1, rouge2, rougel = 0.0, 0.0, 0.0
        num_batches, num_samples = 0, 0
        try:
            with torch.no_grad():
                logging.info(f"[CLIENT {self.client_id}] Evaluation started")
                for batch in self.dataloader:
                    inputs = preprocess_data(batch, self.tokenizer)
                    # Teacher-forced loss on the reference summaries (previously never computed).
                    total_loss += self.model(**inputs).loss.item()
                    outputs = self.model.generate(input_ids=inputs['input_ids'],
                                                  attention_mask=inputs['attention_mask'], max_length=150)
                    preds = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
                    # -100 marks label padding ignored by the loss; it is not a valid token id, so map it to pad.
                    labels = inputs['labels'].masked_fill(inputs['labels'] == -100, self.tokenizer.pad_token_id)
                    refs = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
                    for pred, ref in zip(preds, refs):
                        scores = self.scorer.score(ref, pred)
                        rouge1 += scores['rouge1'].fmeasure
                        rouge2 += scores['rouge2'].fmeasure
                        rougel += scores['rougeL'].fmeasure
                        num_samples += 1
                    num_batches += 1
            # Loss is already a per-batch mean; ROUGE is summed per sample.
            avg_loss = total_loss / max(num_batches, 1)
            avg_rouge1 = rouge1 / max(num_samples, 1)
            avg_rouge2 = rouge2 / max(num_samples, 1)
            avg_rougel = rougel / max(num_samples, 1)
            logging.info(
                f"[CLIENT {self.client_id}] Evaluation completed, average loss: {avg_loss}, "
                f"ROUGE-1: {avg_rouge1}, ROUGE-2: {avg_rouge2}, ROUGE-L: {avg_rougel}")

            # config.get() tolerates a server that sends no "round" key.
            round_scores = {"round": config.get("round"), "rouge1": avg_rouge1, "rouge2": avg_rouge2, "rougel": avg_rougel}
            if self.scores_file.exists():
                with open(self.scores_file, "r") as f:
                    all_scores = json.load(f)
            else:
                all_scores = []

            all_scores.append(round_scores)

            with open(self.scores_file, "w") as f:
                json.dump(all_scores, f)

            return avg_loss, len(self.dataloader.dataset), {"rouge1": avg_rouge1, "rouge2": avg_rouge2,
                                                            "rougel": avg_rougel}
        except Exception as e:
            logging.error(f"[CLIENT {self.client_id}] Error during evaluation: {e}", exc_info=True)
            return total_loss / max(num_batches, 1), len(self.dataloader.dataset), {"rouge1": 0.0, "rouge2": 0.0,
                                                                                    "rougel": 0.0}


def main(client_id):
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
    model = get_model()

    model_path = f"model_client_{client_id}.pth"
    if os.path.exists(model_path):
        logging.info(f"[CLIENT {client_id}] Loading existing model for client {client_id}")
        model.load_state_dict(torch.load(model_path))
    else:
        logging.info(f"[CLIENT {client_id}] Training new model for client {client_id}")

    dataset = load_and_split_data(client_id)
    dataloader = DataLoader(SummarizationDataset(dataset), batch_size=16, shuffle=True)

    client = SummarizationClient(model, dataloader, tokenizer, client_id)

    logging.info(f"[CLIENT {client_id}] Connecting to server at localhost:8080")
    # Flower 1.x takes keyword-only arguments here.
    fl.client.start_numpy_client(server_address="localhost:8080", client=client)


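if __name__ == "__main__":
    import sys

    # From a terminal: `python -m client.client 0`. Inside Jupyter, sys.argv holds kernel
    # flags (e.g. "-f ..."), so fall back to a CLIENT_ID environment variable
    # (a hypothetical convenience, not in the original repo) or to 0.
    arg = sys.argv[1] if len(sys.argv) > 1 else ""
    client_id = int(arg) if arg.isdigit() else int(os.environ.get("CLIENT_ID", "0"))
    main(client_id)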


## 5. Evaluation of Aggregated Models

Evaluates saved aggregated models across rounds and plots ROUGE scores.
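ROUGE-N F1 measures n-gram overlap between a generated summary and its reference, and it is scored per sample, so the per-round averages should divide by the number of samples rather than the number of batches. The sketch below implements a simplified ROUGE-1 F1 (lowercased whitespace tokens, no stemming, so it will not match the `rouge_score` package exactly) plus a sample-level mean; both function names are hypothetical and only illustrate the arithmetic.

```python
from collections import Counter


def rouge1_f1(reference, prediction):
    """Simplified ROUGE-1 F1: clipped unigram overlap on lowercased whitespace tokens."""
    ref = Counter(reference.lower().split())
    pred = Counter(prediction.lower().split())
    overlap = sum((ref & pred).values())  # each token counts min(ref_count, pred_count) times
    if overlap == 0:
        return 0.0
    precision = overlap / sum(pred.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)


def mean_over_samples(pairs):
    """Average per-sample scores over (reference, prediction) pairs."""
    scores = [rouge1_f1(ref, pred) for ref, pred in pairs]
    return sum(scores) / len(scores) if scores else 0.0
```

For example, the reference "the cat sat down" and the prediction "the cat" share two unigrams: precision is 2/2, recall is 2/4, and F1 is 2/3.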


In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import T5Tokenizer
from rouge_score import rouge_scorer
from utils.data_utils import load_local_data, preprocess_data
from utils.model_utils import load_model
import logging
import json
import matplotlib.pyplot as plt

logging.basicConfig(level=logging.INFO)


class SummarizationDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]


def evaluate_model(model, dataloader, tokenizer, max_length=150):
    model.eval()
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=True)
    rouge1, rouge2, rougel = 0.0, 0.0, 0.0
    num_samples = 0
    with torch.no_grad():
        logging.info("Evaluation started")
        for batch in dataloader:
            inputs = preprocess_data(batch, tokenizer)
            # max_length matches the client-side evaluation so the scores are comparable.
            outputs = model.generate(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'],
                                     max_length=max_length)
            preds = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            # -100 marks label padding ignored by the loss; map it to pad before decoding.
            labels = inputs['labels'].masked_fill(inputs['labels'] == -100, tokenizer.pad_token_id)
            refs = tokenizer.batch_decode(labels, skip_special_tokens=True)
            for pred, ref in zip(preds, refs):
                scores = scorer.score(ref, pred)
                rouge1 += scores['rouge1'].fmeasure
                rouge2 += scores['rouge2'].fmeasure
                rougel += scores['rougeL'].fmeasure
                num_samples += 1
    # Scores are summed per sample, so average over samples (not batches).
    avg_rouge1 = rouge1 / max(num_samples, 1)
    avg_rouge2 = rouge2 / max(num_samples, 1)
    avg_rougel = rougel / max(num_samples, 1)
    logging.info(
        f"Evaluation completed, ROUGE-1: {avg_rouge1}, ROUGE-2: {avg_rouge2}, ROUGE-L: {avg_rougel}")

    return {"rouge1": avg_rouge1, "rouge2": avg_rouge2, "rougel": avg_rougel}


def main():
    model_name = 't5-small'
    tokenizer = T5Tokenizer.from_pretrained(model_name)

    test_data_file = "dataset/medical_dataset.json"
    test_data = load_local_data(test_data_file)
    test_dataloader = DataLoader(SummarizationDataset(test_data), batch_size=16, shuffle=False)

    round_scores = []

    for round_num in range(1, 4):  # Assuming 3 rounds, adjust if needed
        model_path = f"aggregated_model_round_{round_num}.pth"
        model = load_model(model_path, model_name)
        scores = evaluate_model(model, test_dataloader, tokenizer)
        scores["round"] = round_num
        round_scores.append(scores)

    with open("aggregated_model_scores.json", "w") as f:
        json.dump(round_scores, f)

    rounds = [score["round"] for score in round_scores]
    rouge1 = [score["rouge1"] for score in round_scores]
    rouge2 = [score["rouge2"] for score in round_scores]
    rougel = [score["rougel"] for score in round_scores]

    plt.figure(figsize=(10, 6))
    plt.plot(rounds, rouge1, label="ROUGE-1", marker='o')
    plt.plot(rounds, rouge2, label="ROUGE-2", marker='o')
    plt.plot(rounds, rougel, label="ROUGE-L", marker='o')
    plt.xlabel("Round")
    plt.ylabel("ROUGE Score")
    plt.title("ROUGE Scores of Aggregated Model")
    plt.legend()
    plt.grid(True)
    plt.savefig("aggregated_model_rouge_scores.png")
    plt.show()


if __name__ == "__main__":
    main()


## 6. How to Run (Notebook workflow)

Because federated training requires **concurrent** processes (server + multiple clients), the simplest flow is:

1. **Start the server** by running the server cell (it will block).  
   - Tip: Run it in one Jupyter **kernel** or start it externally with `python -m server.server`.
2. **Start clients** in **separate** kernels/terminals, each with a distinct `client_id`
   (e.g., in a terminal: `python -m client.client 0`, `python -m client.client 1`, etc.). The server waits until at least two clients (`min_available_clients=2`) are connected before a round starts.
3. After training rounds complete, **run the evaluation** cell to compute ROUGE and plot results.

If staying entirely in notebooks, you can also:

- Call the client cell's `main(client_id)` function directly in separate notebook tabs or background sessions, instead of relying on command-line arguments.
- Or launch background processes using Python's `subprocess` from a cell (advanced).
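The `subprocess` route can be sketched as follows: start the server, give it a moment to bind, then start the clients with the same module commands as the README. This assumes the notebook's working directory is the repository root (so `python -m server.server` resolves) and that a fixed sleep is an acceptable way to wait for the server; `build_commands` and `launch_federation` are hypothetical names.

```python
import subprocess
import sys
import time


def build_commands(num_clients=3):
    """Command lines for one server and `num_clients` clients (README module layout)."""
    server = [sys.executable, "-m", "server.server"]
    clients = [[sys.executable, "-m", "client.client", str(i)] for i in range(num_clients)]
    return server, clients


def launch_federation(num_clients=3, server_startup_s=10.0):
    """Start the server, wait briefly, then start the clients; returns the Popen handles."""
    server_cmd, client_cmds = build_commands(num_clients)
    server = subprocess.Popen(server_cmd)
    time.sleep(server_startup_s)  # crude wait so clients do not connect before the server listens
    clients = [subprocess.Popen(cmd) for cmd in client_cmds]
    return server, clients

# Usage (from the repository root):
# server, clients = launch_federation(num_clients=3)
# for p in clients:
#     p.wait()
# server.wait()
```

Call `.terminate()` on the returned handles if a run needs to be aborted.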

> **Web app (optional):** Your README mentions Flask apps under `web/` — you can add those cells similarly if needed.


## 7. Original README (Reference)

```
Federated Text Summarization using T5 Model

This project explores the application of federated learning for fine-tuning a large language model (T5-small) for text summarization in the healthcare sector.
The objective is to evaluate whether federated learning can enhance model performance while preserving data privacy across multiple decentralized clients.


--Table of Contents--
>Introduction
>Project Structure
>Setup and Installation
>Dataset
>Running the Project
>Technologies Used



>Introduction 

This project aims to fine-tune a T5 model for text summarization using federated learning, ensuring data privacy and security in the healthcare sector.
The federated learning framework Flower is used to manage the communication and aggregation of model updates across multiple decentralized clients.


>Project Structure
federated_text_summarization_T5/
├── client/
│   ├── client.py
├── server/
│   ├── server.py
├── utils/
│   ├── data_utils.py
│   ├── model_utils.py
│   ├── __init__.py
├── dataset/
│   ├── medical_meadow_cord19.json
├── web/
│   ├── app.py
│   ├── static/
│   │   ├── script.js
│   │   ├── style.css
│   ├── templates/
│   │   ├── index.html
├── plot_decentralized_rouge_scores.py
├── evaluate_aggregated_model.py
├── README.md
└── requirements.txt

>Setup and Installation

pip install -r requirements.txt

>Dataset

Download the dataset: https://huggingface.co/datasets/medalpaca/medical_meadow_cord19?row=13


>Running the Project

Run each command in a separate terminal.


Create a virtual environment:
1. python -m venv env
   source env/bin/activate  # On Windows: `env\Scripts\activate`


Starting the server:
2. python -m server.server


Starting the clients:
3. python -m client.client 0
4. python -m client.client 1
5. python -m client.client 2

ROUGE score evaluation:
6. python -m evaluate_aggregated_model

Starting the web application:
7. python -m web.app
8. python -m web.app
9. python -m web.app



Running the Web Application

Access the web interface at http://127.0.0.1:5000.
Access the web interface at http://127.0.0.1:5001.
Access the web interface at http://127.0.0.1:5002.


>Technologies Used

PyTorch: For model training and optimization.
Flower: Federated learning framework to manage client-server communication.
Hugging Face Transformers: For using the T5 model.
Flask: To deploy the web application.
PyCharm IDE: For development and debugging.
```