In [2]:
import sys
# Make the project's packages (e.g. `main.*`) importable from this notebook.
# NOTE(review): this is a machine-specific absolute path; adjust it when running elsewhere.
project_root = "d:/MachineLearning/federated_vae"
# Guard the append so re-running this cell does not keep adding duplicate entries.
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import flwr
from flwr.client import Client, ClientApp, NumPyClient
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg, FedAdagrad
from flwr.simulation import run_simulation
from flwr_datasets import FederatedDataset
from flwr.common import ndarrays_to_parameters, NDArrays, Scalar, Context

from main.data.preprocess import Preprocess
from main.data.basic_dataset import RawDataset

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
# Constructing torch.device("cuda") unconditionally succeeds even without CUDA,
# but the first tensor/model moved to it would then fail.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Training on {DEVICE}")
print(f"Flower {flwr.__version__} / PyTorch {torch.__version__}")



  from .autonotebook import tqdm as notebook_tqdm
2025-07-11 23:21:38,409	INFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


Training on cuda
Flower 1.19.0 / PyTorch 2.5.1+cu121


In [3]:
from main.data.basic_dataset import BasicDataset, RawDataset
# Load the two corpora; the size/vocab statistics shown in the cell output
# appear while these objects are being constructed.
newsgroups_dir = "../../data/20NG"
imdb_dir = "../../data/IMDB"
dataset1 = BasicDataset(dataset_dir=newsgroups_dir)
dataset2 = BasicDataset(dataset_dir=imdb_dir)

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543
train_size:  25000
test_size:  25000
vocab_size:  5000
average length: 94.966


In [4]:
# Show the per-corpus vocabularies before merging them.
print(dataset1.vocab)
print(dataset2.vocab)

# Build the merged vocabulary in a NEW list. Binding `all_vocab` directly to
# `dataset1.vocab` would alias the two names, so every append below would also
# grow dataset1's own vocabulary.
all_vocab = list(dataset1.vocab)
# A set gives O(1) membership checks instead of scanning the list for every word.
seen_words = set(all_vocab)
for word in dataset2.vocab:
    if word not in seen_words:
        seen_words.add(word)
        all_vocab.append(word)



In [5]:
# Sort into a new list rather than in place: an in-place sort would also reorder
# every other reference to the same list object (for example a dataset's own
# `vocab` if `all_vocab` is still bound to it).
all_vocab = sorted(all_vocab)
print(all_vocab)
print(len(all_vocab))

6932


In [None]:
from main.data.preprocess import Preprocess

# Wrap the 20NG training texts in a RawDataset that uses the merged vocabulary
# instead of the corpus' own one.
newsgroups_train_texts = dataset1.train_texts
dataset_1 = RawDataset(newsgroups_train_texts, vocab=all_vocab)

loading train texts: 100%|██████████| 11314/11314 [00:01<00:00, 5717.38it/s]
parsing texts: 100%|██████████| 11314/11314 [00:01<00:00, 9494.48it/s] 


In [26]:
from typing import List
from main.utils._utils import file_utils

# Merge the vocabularies of several datasets into one sorted list
def get_all_vocab(dirs: List[str]):
    """Return the sorted union of the vocabularies stored under ``dirs``.

    Args:
        dirs: Dataset directories; ``vocab.txt`` is read from each one via
            ``file_utils.read_text``.

    Returns:
        A sorted list containing every distinct entry found in any of the
        vocab files (an empty list when ``dirs`` is empty).
    """
    merged_words = {
        word
        for dataset_dir in dirs
        for word in file_utils.read_text(f'{dataset_dir}/vocab.txt')
    }
    return sorted(merged_words)

# Merged vocabulary of the 20NG and IMDB corpora, plus its size
vocab_dirs = ["../../data/20NG", "../../data/IMDB"]
x = get_all_vocab(vocab_dirs)
print(x)
print(len(x))

6932


In [None]:
# Split the training texts of one dataset into contiguous shards
def split_data(dir:str, num_split:int, vocab = None, batch_size = 200, device = "cuda") -> "List[RawDataset]":
    """Load a dataset and split its training texts into ``num_split`` shards.

    Every training text ends up in exactly one shard. When the number of texts
    is not divisible by ``num_split``, the first ``len(train_texts) % num_split``
    shards receive one extra text, so no sample is dropped.

    Args:
        dir: Dataset directory passed to ``BasicDataset``.
        num_split: Number of shards to create; must be at least 1.
        vocab: Vocabulary handed to every ``RawDataset``. Defaults to the
            vocabulary of the loaded ``BasicDataset``.
        batch_size: Forwarded to ``BasicDataset``.
        device: Forwarded to ``BasicDataset`` and to every ``RawDataset``.

    Returns:
        A list of ``num_split`` ``RawDataset`` objects, in the original text order.

    Raises:
        ValueError: If ``num_split`` is smaller than 1.
    """
    if num_split < 1:
        raise ValueError(f"num_split must be a positive integer, got {num_split}")

    # The BasicDataset is only needed for its training texts and default vocab.
    source = BasicDataset(dir, batch_size=batch_size, device = device)
    train_texts = source.train_texts
    if vocab is None:
        vocab = source.vocab

    # Integer arithmetic: base shard size plus the leftover texts to distribute.
    base_size, remainder = divmod(len(train_texts), num_split)

    datasets = []
    start = 0
    for i in range(num_split):
        # The first `remainder` shards take one extra text each.
        stop = start + base_size + (1 if i < remainder else 0)
        datasets.append(RawDataset(train_texts[start:stop], vocab = vocab, device = device))
        start = stop

    return datasets

# Two shards of the 20NG training texts, both using the merged vocabulary `x`
newsgroups_dir = "../../data/20NG"
datasets = split_data(newsgroups_dir, num_split=2, vocab=x)
    

train_size:  11314
test_size:  7532
vocab_size:  5000
average length: 110.543


loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 7496.89it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 9769.95it/s] 
loading train texts: 100%|██████████| 5657/5657 [00:00<00:00, 7883.46it/s]
parsing texts: 100%|██████████| 5657/5657 [00:00<00:00, 10247.36it/s]


In [31]:
# Number of training texts held by the first shard
first_split = datasets[0]
len(first_split.train_texts)

5657

In [3]:
from utils import get_all_vocab, split_data

# Rebuild the merged vocabulary with the helper imported from `utils`
vocab_dirs = ["../../data/20NG", "../../data/IMDB"]
x = get_all_vocab(vocab_dirs)
print(x)

