In [None]:
!pip install flwr transformers torch torchvision accelerate
!pip install --upgrade flwr

In [None]:
import flwr as fl
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
# Authenticate with the Hugging Face Hub (presumably required because the
# Gemma checkpoint is gated — confirm). new_session=False skips the prompt
# when a token is already saved on this machine.
from huggingface_hub import login

login(token=None, new_session=False)

In [None]:
# The preceding cell already calls login(); with the default new_session=True
# this widget would prompt again even when a token is saved. new_session=False
# only prompts if no token is available yet.
from huggingface_hub import notebook_login

notebook_login(new_session=False)

In [None]:
# Single source of truth for the checkpoint, so the tokenizer and every model
# copy (including those built per simulated client) stay in sync.
MODEL_NAME = "google/gemma-2-2b-it"


def load_model(model_name=MODEL_NAME):
    """Load the causal-LM checkpoint ``model_name`` in bfloat16.

    ``device_map="auto"`` lets Accelerate place the weights on the available
    device(s). ``client_fn`` calls this to give each simulated client its own
    model instance.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )


tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = load_model()

In [None]:
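# Load the per-language datasets.
# TODO: define `load_dataset(lang)` here. `client_fn` calls it with a language
# name ("hindi", "marathi", ...) and passes the result to `GemmaFLClient`, which
# iterates it as (input_text, target_text) batches and also uses `len(...)` and
# `.dataset` on it (e.g. a torch DataLoader over string pairs).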




In [None]:
class GemmaFLClient(fl.client.NumPyClient):
    """Flower NumPy client that fine-tunes a causal LM on one client's data.

    Args:
        model: a Hugging Face causal-LM model (``AutoModelForCausalLM``).
        dataloader: iterable of ``(input_text, target_text)`` batches that
            also supports ``len()`` and exposes ``.dataset`` (e.g. a torch
            ``DataLoader`` over string pairs).

    The tokenizer is taken from the module-level ``tokenizer`` global.
    """

    def __init__(self, model, dataloader):
        self.model = model
        self.dataloader = dataloader
        self.tokenizer = tokenizer
        self.model.train()

    def get_parameters(self, config):
        """Return the model's parameters as float32 NumPy arrays.

        Parameters are detached (``Tensor.numpy()`` refuses tensors that
        require grad) and upcast to float32 because NumPy has no bfloat16
        dtype.
        """
        return [p.detach().cpu().float().numpy() for p in self.model.parameters()]

    def set_parameters(self, parameters):
        """Copy ``parameters`` (in ``get_parameters`` order) into the model in place.

        ``copy_`` keeps each parameter's own dtype and device, so a bfloat16
        model stays bfloat16 after receiving float32 arrays.

        Raises:
            ValueError: if the number of arrays does not match the model.
        """
        params = list(self.model.parameters())
        if len(params) != len(parameters):
            raise ValueError(
                f"Expected {len(params)} parameter arrays, got {len(parameters)}"
            )
        with torch.no_grad():
            for p, new_val in zip(params, parameters):
                p.copy_(torch.tensor(new_val))

    def _encode(self, input_text, target_text):
        """Tokenize a batch of (prompt, target) pairs for causal-LM training.

        A causal LM is trained to predict the next token of its own input, so
        prompt and target are concatenated (verbatim — any separator must
        already be in the data) and ``labels`` has the same shape as
        ``input_ids``. Padding and prompt positions in ``labels`` are set to
        -100 so the loss only covers target tokens.
        """
        prompts = [input_text] if isinstance(input_text, str) else list(input_text)
        targets = [target_text] if isinstance(target_text, str) else list(target_text)
        full_texts = [p + t for p, t in zip(prompts, targets)]

        enc = self.tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True)
        labels = enc["input_ids"].clone()
        is_pad = enc["attention_mask"] == 0
        labels[is_pad] = -100

        # Prompt length as tokenized on its own (includes any special prefix
        # tokens the tokenizer adds). Approximate if tokens merge across the
        # prompt/target boundary; if truncation leaves no target tokens the
        # row is fully masked and contributes NaN loss.
        prompt_lens = [len(ids) for ids in self.tokenizer(prompts, truncation=True)["input_ids"]]
        pad_counts = is_pad.sum(dim=1).tolist()
        left_padded = self.tokenizer.padding_side == "left"
        for row, (plen, npad) in enumerate(zip(prompt_lens, pad_counts)):
            start = npad if left_padded else 0
            labels[row, start:start + plen] = -100

        enc["labels"] = labels
        return enc.to(self.model.device)

    def fit(self, parameters, config):
        """Load global ``parameters``, train locally, and return the update.

        ``config`` may carry ``"lr"`` (default 5e-5) and ``"local_epochs"``
        (default 1) sent by the server.

        Returns:
            (updated parameters, number of local examples, empty metrics dict)
        """
        config = config or {}
        self.set_parameters(parameters)
        # evaluate() switches the model to eval mode; restore train-mode layers.
        self.model.train()
        lr = float(config.get("lr", 5e-5))
        local_epochs = int(config.get("local_epochs", 1))
        optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        for _ in range(local_epochs):
            for input_text, target_text in self.dataloader:
                batch = self._encode(input_text, target_text)
                loss = self.model(**batch).loss
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

        return self.get_parameters(config), len(self.dataloader.dataset), {}

    def evaluate(self, parameters, config):
        """Load global ``parameters`` and compute the mean per-batch loss locally.

        Returns:
            (mean per-batch loss, number of local examples, empty metrics dict)
        """
        self.set_parameters(parameters)
        self.model.eval()
        total_loss = 0.0
        num_batches = 0
        with torch.no_grad():
            for input_text, target_text in self.dataloader:
                batch = self._encode(input_text, target_text)
                total_loss += self.model(**batch).loss.item()
                num_batches += 1

        # An empty loader reports 0.0 instead of raising ZeroDivisionError.
        mean_loss = total_loss / num_batches if num_batches else 0.0
        return float(mean_loss), len(self.dataloader.dataset), {}

In [None]:
def client_fn(cid: str):
    """Build the Flower client for simulated client ``cid``.

    Each client id maps to one language; the client gets its own model
    instance and that language's data loader.

    Raises:
        ValueError: if ``cid`` is not one of the configured client ids.
    """
    lang_map = {"0": "hindi", "1": "marathi", "2": "gujarati", "3": "bengali"}
    try:
        lang = lang_map[str(cid)]
    except KeyError:
        raise ValueError(
            f"Unknown client id {cid!r}; expected one of {sorted(lang_map)}"
        ) from None
    model = load_model()
    data_loader = load_dataset(lang)
    # Flower expects a `Client`; returning a NumPyClient directly is
    # deprecated, so wrap it with to_client().
    return GemmaFLClient(model, data_loader).to_client()


In [None]:
# Federated averaging: wait until all four clients are available, then use
# every one of them (fraction 1.0, at least four) for training each round.
strategy = fl.server.strategy.FedAvg(
    min_available_clients=4,
    min_fit_clients=4,
    fraction_fit=1.0,
)

In [None]:
# Number of federated rounds; 1 matches Flower's ServerConfig default.
NUM_ROUNDS = 1

# Ray only exposes a GPU to a simulated client that reserves one, so request a
# whole GPU per client when CUDA is present (with a single GPU this also runs
# the clients one at a time).
_client_resources = {
    "num_cpus": 1,
    "num_gpus": 1.0 if torch.cuda.is_available() else 0.0,
}

history = fl.simulation.start_simulation(
    client_fn=client_fn,
    num_clients=4,
    # start_simulation's `config` is a ServerConfig controlling the round count;
    # there is no `fl.simulation.ClientSimulationConfig`.
    config=fl.server.ServerConfig(num_rounds=NUM_ROUNDS),
    strategy=strategy,
    client_resources=_client_resources,
)
