# Personalized Federated Learning for Seq2Seq

This notebook implements a Federated Learning approach for Sequence-to-Sequence models, specifically focusing on personalization for different languages. It includes data handling, model definitions (Encoder-Decoder with Attention), training loops (including Federated Averaging), and evaluation metrics (BLEU score).

## Imports and Setup

In [11]:
import os
import random
import time
import math
import copy
import csv
import re
import unicodedata
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, Subset
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from functools import partial
from tqdm.auto import tqdm
import sacrebleu
import optuna

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Set seed for reproducibility
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

Using device: cuda


## Data Handling
Functions and classes for loading, preprocessing, and managing the dataset.

In [12]:
SOS_token = 0
EOS_token = 1

class Lang:
    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "SOS", 1: "EOS"}
        self.n_words = 2  # Count SOS and EOS

    def addSentence(self, sentence):
        for word in sentence.split(' '):
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1

def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )

def normalizeString(s):
    s = unicodeToAscii(s.lower().strip())
    s = re.sub(r"([.!?])", r" \1", s)
    s = re.sub(r"[^a-zA-Z.!?]+", r" ", s)
    return s

def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    # Read the file and split into lines
    # Assumes data is in 'data/' directory
    try:
        lines = open('data/%s-%s.txt' % (lang1, lang2), encoding='utf-8').read().strip().split('\n')
    except FileNotFoundError:
        # Fallback if file naming is reversed or different
        try:
            lines = open('data/%s-%s.txt' % (lang2, lang1), encoding='utf-8').read().strip().split('\n')
        except FileNotFoundError:
             # Handle specific cases like flores_kir or kir_test if they exist in root or data
             if lang1 == 'flores_kir':
                 lines1 = open('data/flores200_devtest_source_eng_Latn-run_Latn.txt', encoding='utf-8').read().strip().split('\n')
                 lines2 = open('data/flores200_devtest_target_eng_Latn-run_Latn.txt', encoding='utf-8').read().strip().split('\n')
                 lines = [f"{l1}\t{l2}" for l1, l2 in zip(lines1, lines2)]
             elif lang1 == 'kir_test':
                 lines = open('data/kir_test.txt', encoding='utf-8').read().strip().split('\n')
             else:
                 raise

    # Split every line into pairs and normalize
    pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]

    # Reverse pairs, make Lang instances
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs

MAX_LENGTH = 20

def filterPair(p):
    return len(p[0].split(' ')) < MAX_LENGTH and \
        len(p[1].split(' ')) < MAX_LENGTH

def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]

def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs

def indexesFromSentence(lang, sentence):
    return [lang.word2index[word] for word in sentence.split(' ')]

def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(EOS_token)
    return torch.tensor(indexes, dtype=torch.long, device=device).view(1, -1)

def tensorsFromPair(pair, input_lang, output_lang):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)

def get_dataloader(batch_size, language='fra'):
    input_lang, output_lang, pairs = prepareData('eng', language, True)

    n = len(pairs)
    input_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)
    target_ids = np.zeros((n, MAX_LENGTH), dtype=np.int32)

    for idx, (inp, tgt) in enumerate(pairs):
        inp_ids = indexesFromSentence(input_lang, inp)
        tgt_ids = indexesFromSentence(output_lang, tgt)
        inp_ids.append(EOS_token)
        tgt_ids.append(EOS_token)
        input_ids[idx, :len(inp_ids)] = inp_ids
        target_ids[idx, :len(tgt_ids)] = tgt_ids

    indices_tensor = torch.arange(n)

    train_data = TensorDataset(torch.LongTensor(input_ids).to(device),
                               torch.LongTensor(target_ids).to(device),
                               indices_tensor.to(device))

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
    return input_lang, output_lang, train_dataloader, pairs

def limited_data_loader(original_dataloader, num_samples, random=True):
    dataset = original_dataloader.dataset
    assert len(dataset) >= num_samples, "The original dataset has fewer samples than requested"
    if random:
        indices = np.random.choice(len(dataset), num_samples, replace=False)
    else:
        indices = range(num_samples)
    subset = Subset(dataset, indices)
    new_dataloader = DataLoader(subset, batch_size=original_dataloader.batch_size, shuffle=False, num_workers=original_dataloader.num_workers)
    return new_dataloader

def get_pair_index(dataloader):
    index_list = []
    for batch in dataloader:
        index_list.extend(batch[2].tolist())
    return index_list

def split_dataloader(train_dataloader, ratio=0.9):
    dataset = train_dataloader.dataset
    train_size = int(ratio * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

    train_loader = DataLoader(train_dataset, batch_size=train_dataloader.batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=train_dataloader.batch_size, shuffle=False)
    return train_loader, test_loader, ratio

## Model Architecture
Encoder and Decoder definitions, including Attention mechanism.

In [13]:
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, dropout_p=0.1):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, input):
        embedded = self.dropout(self.embedding(input))
        output, hidden = self.gru(embedded)
        return output, hidden

class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.Wa = nn.Linear(hidden_size, hidden_size)
        self.Ua = nn.Linear(hidden_size, hidden_size)
        self.Va = nn.Linear(hidden_size, 1)

    def forward(self, query, keys):
        scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))
        scores = scores.squeeze(2).unsqueeze(1)
        weights = F.softmax(scores, dim=-1)
        context = torch.bmm(weights, keys)
        return context, weights

class AttnDecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, dropout_p=0.1):
        super(AttnDecoderRNN, self).__init__()
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.attention = BahdanauAttention(hidden_size)
        self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, output_size)
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):
        batch_size = encoder_outputs.size(0)
        decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(SOS_token)
        decoder_hidden = encoder_hidden
        decoder_outputs = []
        attentions = []

        for i in range(MAX_LENGTH):
            decoder_output, decoder_hidden, attn_weights = self.forward_step(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_outputs.append(decoder_output)
            attentions.append(attn_weights)

            if target_tensor is not None:
                # Teacher forcing: Feed the target as the next input
                decoder_input = target_tensor[:, i].unsqueeze(1) # Teacher forcing
            else:
                # Without teacher forcing: use its own predictions as the next input
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze(-1).detach()  # detach from history as input

        decoder_outputs = torch.cat(decoder_outputs, dim=1)
        decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)
        attentions = torch.cat(attentions, dim=1)

        return decoder_outputs, decoder_hidden, attentions

    def forward_step(self, input, hidden, encoder_outputs):
        embedded =  self.dropout(self.embedding(input))
        query = hidden.permute(1, 0, 2)
        context, attn_weights = self.attention(query, encoder_outputs)
        input_gru = torch.cat((embedded, context), dim=2)
        output, hidden = self.gru(input_gru, hidden)
        output = self.out(output)
        return output, hidden, attn_weights

class EncoderMetaModel(nn.Module):
    def __init__(self, encoders_dict, decoders_dict):
        super(EncoderMetaModel, self).__init__()
        self.encoders_dict = nn.ModuleDict(encoders_dict)
        self.decoders_dict = nn.ModuleDict(decoders_dict)
        self.encoder_weights = nn.Parameter(torch.ones(len(encoders_dict)) / len(encoders_dict))
        self.decoder_weights = nn.Parameter(torch.ones(len(decoders_dict)) / len(decoders_dict))
    
    def forward(self, x):
        encoder_combined_output = 0
        for i, encoder in enumerate(self.encoders_dict.values()):
            encoder_combined_output += self.encoder_weights[i] * encoder(x)
        
        decoder_combined_output = 0
        for i, decoder in enumerate(self.decoders_dict.values()):
            decoder_combined_output += self.decoder_weights[i] * decoder(encoder_combined_output)
        
        return decoder_combined_output

## Training and Evaluation
Functions for training the model, evaluating performance, and calculating BLEU scores.

In [14]:
def train_epoch(dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion):
    total_loss = 0
    for data in dataloader:
        input_tensor, target_tensor, _ = data

        encoder_optimizer.zero_grad()
        decoder_optimizer.zero_grad()

        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)

        loss = criterion(
            decoder_outputs.view(-1, decoder_outputs.size(-1)),
            target_tensor.view(-1)
        )
        loss.backward()

        encoder_optimizer.step()
        decoder_optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)

def train(train_dataloader, encoder, decoder, n_epochs, input, output, pairs, test_pairs=None, filename=None, learning_rate=0.001, print_every=100, plot_every=100):
    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()

    for epoch in range(1, n_epochs + 1):
        loss = train_epoch(train_dataloader, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion)
        print_loss_total += loss
        plot_loss_total += loss

        if epoch % print_every == 0:
            print_loss_avg = print_loss_total / print_every
            print_loss_total = 0
            
            bleu = evaluateBleu(encoder, decoder, input, output, pairs, n=200)
            test_bleu = 0
            if test_pairs is not None:
                test_bleu = evaluateBleu(encoder, decoder, input, output, test_pairs, n=min(200,len(test_pairs)))
            
            print('Epoch %d >> Loss: %s, Bleu: %s, Test Bleu: %s' % (epoch, print_loss_avg, bleu, test_bleu))
            
            if filename is not None:
                with open(filename, 'a') as f:
                    writer = csv.writer(f)
                    writer.writerow([epoch, print_loss_avg, bleu, test_bleu])

        if epoch % plot_every == 0:
            plot_loss_avg = plot_loss_total / plot_every
            plot_losses.append(plot_loss_avg)
            plot_loss_total = 0

    return encoder, decoder, test_pairs

def evaluate(encoder, decoder, sentence, input_lang, output_lang):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        encoder_outputs, encoder_hidden = encoder(input_tensor)
        decoder_outputs, decoder_hidden, decoder_attn = decoder(encoder_outputs, encoder_hidden)

        _, topi = decoder_outputs.topk(1)
        decoded_ids = topi.squeeze()

        decoded_words = []
        for idx in decoded_ids:
            if idx.item() == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(output_lang.index2word[idx.item()])
    return decoded_words, decoder_attn

def evaluateBleu(encoder, decoder, input_lang, output_lang, pairs, n=10, verbose=False):
    # Randomly select n pairs to evaluate
    if n > len(pairs):
        n = len(pairs)
    
    # Use random indices
    numbers = random.sample(range(len(pairs)), n)
    
    bleu_sum = 0
    sys = []
    refs = []

    for num in numbers:
        pair = pairs[num]
        output_words, _ = evaluate(encoder, decoder, pair[0], input_lang, output_lang)
        output_sentence = ' '.join(output_words).split(' ')
        if verbose:
            print('>', pair[0])
            print('=', pair[1])
            print('<', output_sentence)
            print('')
        
        # Prepare for sacrebleu
        # Remove <EOS> if present
        if '<EOS>' in output_sentence:
            output_sentence.remove('<EOS>')
        
        sys.append(" ".join(output_sentence))
        refs.append(" ".join(pair[1].split(' ')))
    
    # Calculate BLEU
    bleu = sacrebleu.corpus_bleu(sys, [refs])
    return bleu.score

def evaluateMultipleBleu(encoders, decoders, input_output_langs, n=10):
    bleu_sum = 0
    bleu_test_sum = 0
    lang_length = len(encoders)
    bleus = []
    test_bleus = []
    for i in range(lang_length):
        bleu = evaluateBleu(encoders[i], decoders[i], input_output_langs[i][0], input_output_langs[i][1], input_output_langs[i][2], n=n)
        bleu_sum += bleu
        bleus.append(bleu)
    for i in range(lang_length):
        test_bleu = evaluateBleu(encoders[i], decoders[i], input_output_langs[i][0], input_output_langs[i][1], input_output_langs[i][3], n=n)
        bleu_test_sum += test_bleu
        test_bleus.append(test_bleu)
    return (bleu_sum / lang_length), (bleu_test_sum / lang_length), bleus, test_bleus

def personalize(lang, rounds, encoder_weights=None, decoder_weights=None, sample=None, save=False, lang_info=None):
    if lang_info is None:
        input_lang, output_lang, train_dataloader, train_pairs = get_dataloader(batch_size=32, language=lang)
    else:
        input_lang, output_lang, train_dataloader, train_pairs = lang_info

    if sample is not None:
        train_dataloader = limited_data_loader(train_dataloader, sample, random=True)

    # Get test data
    test_pairs = train_pairs[int(len(train_pairs)*0.9):] # Simple 10% split for testing
    train_pairs = train_pairs[:int(len(train_pairs)*0.9)]

    hidden_size = 256
    encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)

    if encoder_weights is not None:
        encoder.load_state_dict(encoder_weights)
    if decoder_weights is not None:
        decoder.load_state_dict(decoder_weights)

    filename = None
    if save:
        is_FL = 0 if encoder_weights is None else 1
        num = 1
        filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}.csv"
        while os.path.isfile(filename):
            num += 1
            filename = f"P|{lang}_{sample}-shot_FL-{is_FL}_epoch{rounds}||{num}.csv"

    train(train_dataloader, encoder, decoder, rounds, print_every=25, plot_every=5, filename=filename, input=input_lang, output=output_lang, pairs=train_pairs, test_pairs=test_pairs)
    return encoder, decoder, test_pairs

## Federated Learning Functions
Core logic for Federated Averaging.

In [15]:
class ClientUpdate(object):
    def __init__(self, train_dataloader, learning_rate, epochs, sch_flag):
        self.train_loader = train_dataloader
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.sch_flag = sch_flag

    def train(self, encoder, decoder):
        criterion = nn.NLLLoss()
        encoder_optimizer = optim.Adam(encoder.parameters(), lr=self.learning_rate)
        decoder_optimizer = optim.Adam(decoder.parameters(), lr=self.learning_rate)
        
        epoch_loss = []
        num_pairs = 0
        
        for epoch in range(1, self.epochs + 1):
            batch_loss = []
            for data in self.train_loader:
                input_tensor, target_tensor, _ = data
                
                encoder_optimizer.zero_grad()
                decoder_optimizer.zero_grad()
                
                encoder_outputs, encoder_hidden = encoder(input_tensor)
                decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)
                
                loss = criterion(
                    decoder_outputs.view(-1, decoder_outputs.size(-1)),
                    target_tensor.view(-1)
                )
                loss.backward()
                
                encoder_optimizer.step()
                decoder_optimizer.step()
                
                batch_loss.append(loss.item())
                num_pairs += input_tensor.size(0)
                
            epoch_loss.append(sum(batch_loss) / len(batch_loss))
            
        return encoder.state_dict(), decoder.state_dict(), sum(epoch_loss) / len(epoch_loss), num_pairs

def training(encoders, decoders, input_output_lang, rounds, lr, ds, C, K, E, filename=None, batch_size=None, hidden_size=None, weighting=False, lexical_weights=None):
    if batch_size is None:
        batch_size = 32
        
    # Setup for Optuna
    input_lang_target, output_lang_target, target_dataloader_target, pairs_target = get_dataloader(batch_size, language='cat')
    langs = ['bem', 'kin', 'swa']

    # Global weights initialization
    global_encoder_weights = {key: value for key, value in encoders[0].state_dict().items() if 'embedding' not in key}
    global_decoder_weights = {key: value for key, value in decoders[0].state_dict().items() if 'embedding' not in key and 'out' not in key}
    
    train_loss = []
    best_bleu = 0
    best_test_bleu = 0
    start = time.time()
    
    lex_weighting = False
    if lexical_weights is not None:
        lex_weighting = True
    lexical_weights_encoder = lexical_weights
    lexical_weights_decoder = lexical_weights
    
    if filename is not None:
        with open(filename, 'a') as f:
            writer = csv.writer(f)
            writer.writerow(['Rounds', 'Learning Rate', 'Client Fraction', 'Client Number', 'Local Epochs', 'Batch Size', 'Hidden Size', 'Weighting', 'Lexical Weighting'])
            writer.writerow([rounds, lr, C, K, E, batch_size, hidden_size, weighting, lex_weighting])
    
    for curr_round in range(1, rounds + 1):
        w_encoder, w_decoder, local_loss, num_pairs = [], [], [], []
        
        m = max(int(C * K), 1)
        idxs_users = np.random.choice(range(K), m, replace=False)
        
        for k in tqdm(idxs_users):
            local_update = ClientUpdate(train_dataloader=ds[k], learning_rate=lr, epochs=E, sch_flag=False)
            
            e_og = encoders[k].state_dict()
            d_og = decoders[k].state_dict()
            e_og.update(global_encoder_weights)
            d_og.update(global_decoder_weights)
            encoders[k].load_state_dict(e_og)
            decoders[k].load_state_dict(d_og)
            
            encoder_weights, decoder_weights, loss, num = local_update.train(encoders[k], decoders[k])
            
            w_encoder.append({key: value for key, value in copy.deepcopy(encoder_weights).items() if 'embedding' not in key})
            w_decoder.append({key: value for key, value in copy.deepcopy(decoder_weights).items() if 'embedding' not in key and 'out' not in key})
            local_loss.append(copy.deepcopy(loss))
            num_pairs.append(num)
            
        # Optuna Optimization
        if curr_round % 5 == 0 and curr_round >= 25:
            # global_encoder_model, global_decoder_model, test_pairs = personalize('kir', 20, save=False, encoder_weights=global_encoder_weights, decoder_weights=global_decoder_weights, sample=None)
            objective_full = partial(objective, encoder_models = encoders, decoder_models = decoders, input_lang=input_lang_target, output_lang=output_lang_target, target_dataloader=target_dataloader_target, target_pairs=pairs_target, test_pairs=None, sources = langs, device = device)

            study = optuna.create_study(study_name="optimizing weights", direction="maximize", storage=f"sqlite:///studies/{filename[:-4]}.db", load_if_exists=True)
            study.optimize(objective_full, n_trials=10)

            # print(f"ENCODERS: {proportions_encoder}, {proportions_decoder}")
            print(f"PARAMS: {study.best_params}, VALUE: {study.best_value}")
            lexical_weights_encoder = []
            lexical_weights_decoder = []
            for i, (k,v) in enumerate(study.best_params.items()):
                if i % 2 == 0:
                    lexical_weights_encoder.append(v)
                else:
                    lexical_weights_decoder.append(v)
            lexical_weights_encoder.append(100 - sum(lexical_weights_encoder))
            lexical_weights_decoder.append(100 - sum(lexical_weights_decoder))
            
            # Enable lexical weighting after optimization
            lex_weighting = True

        # Updating global weights
        weights_avg_e = copy.deepcopy(w_encoder[0])
        weights_sum = sum(num_pairs)
        client_weights = [weight / weights_sum for weight in num_pairs]
        client_weights_encoder = lexical_weights_encoder
        client_weights_decoder = lexical_weights_decoder
        
        if weighting:
            client_weights_encoder = client_weights_decoder = client_weights
            
        if lex_weighting:
            if weighting:
                 # Placeholder for complex weighting if needed, but for now using the logic that matches Optuna output usage
                 pass
            else:
                # Normalize weights if they are not already (Optuna weights sum to 100 approx?)
                # The Optuna objective ensures they sum to 100 via logic, but here we might need to be careful.
                # The original code had:
                client_weights_encoder = [0.25 + 0.25 * (weight / sum(lexical_weights_encoder)) for weight in lexical_weights_encoder]
                client_weights_decoder = [0.25 + 0.25 * (weight / sum(lexical_weights_decoder)) for weight in lexical_weights_decoder]
        
        # print(f"Encoder:{client_weights_encoder}, Decoder:{client_weights_decoder}")

        for k in weights_avg_e.keys():
            if weighting or lex_weighting:
                weights_avg_e[k] *= client_weights_encoder[0]
                for i in range(1, len(w_encoder)):
                    weights_avg_e[k] += w_encoder[i][k] * client_weights_encoder[i]
            else:
                for i in range(1, len(w_encoder)):
                    weights_avg_e[k] += w_encoder[i][k]
                weights_avg_e[k] = torch.div(weights_avg_e[k], len(w_encoder))
            
        global_encoder_weights = weights_avg_e
        
        weights_avg_d = copy.deepcopy(w_decoder[0])
        for k in weights_avg_d.keys():
            if weighting or lex_weighting:
                weights_avg_d[k] *= client_weights_decoder[0]
                for i in range(1, len(w_decoder)):
                    weights_avg_d[k] += w_decoder[i][k] * client_weights_decoder[i]
            else:
                for i in range(1, len(w_decoder)):
                    weights_avg_d[k] += w_decoder[i][k]
                weights_avg_d[k] = torch.div(weights_avg_d[k], len(w_decoder))
            
        global_decoder_weights = weights_avg_d
        
        loss_avg = sum(local_loss) / len(local_loss)
        train_loss.append(loss_avg)
        
        bleu, test_bleu, bleu_list, test_bleu_list = evaluateMultipleBleu(encoders, decoders, input_output_lang, n=200)
        
        if best_bleu < bleu:
            best_bleu = bleu
        if best_test_bleu < test_bleu:
            best_test_bleu = test_bleu
            
        print(f"Round {curr_round} >> Loss: {loss_avg}, BLEU:{bleu}, TEST:{test_bleu}")
        
        if filename is not None:
            with open(filename, 'a') as f:
                writer = csv.writer(f)
                writer.writerow([curr_round, loss_avg, bleu, best_bleu, test_bleu, best_test_bleu, str(bleu_list), str(test_bleu_list)])
                
    end = time.time()
    print("Training Done!")
    print("Total time taken to Train: {}".format(end - start))
    
    return global_encoder_weights, global_decoder_weights

In [16]:
def objective(trial, encoder_models, decoder_models, input_lang, output_lang, target_dataloader, target_pairs, test_pairs, sources, device):
    batch_size = 32
    hidden_size = 256
    
    STEP_SIZE = 1

    encoder_weights = []
    decoder_weights = []

    # This is what we fine-tune for
    proportions_encoder = torch.ones(len(encoder_models))
    proportions_decoder = torch.ones(len(encoder_models))

    # Reset Heads
    global_encoder_model = EncoderRNN(input_lang.n_words, hidden_size).to(device)
    global_decoder_model = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
        
    # Combine Decoder and Encoders
    # Make a list of encoder and decoder weights
    for i in range(len(encoder_models)):
        encoder_weights.append({key: value for key, value in copy.deepcopy(encoder_models[i].state_dict()).items() if 'embedding' not in key})
        decoder_weights.append({key: value for key, value in copy.deepcopy(decoder_models[i].state_dict()).items() if 'embedding' not in key and 'out' not in key})
    
    weights_avg_e = copy.deepcopy(encoder_weights[0])
    weights_avg_d = copy.deepcopy(decoder_weights[0])

    # Make state_dict empty with zeroes
    for k in weights_avg_e:
        weights_avg_e[k].zero_()
    for k in weights_avg_d:
        weights_avg_d[k].zero_()
    
    upper_limit_decoder = upper_limit_encoder = 100

    for i, source_lang in enumerate(sources[:-1]):
        proportions_encoder[i] = trial.suggest_int(f"{source_lang}_encoder", 0, upper_limit_encoder, step=STEP_SIZE)
        proportions_decoder[i] = trial.suggest_int(f"{source_lang}_decoder", 0, upper_limit_decoder, step=STEP_SIZE)
        upper_limit_encoder -= proportions_encoder[i]
        upper_limit_decoder -= proportions_decoder[i]
    proportions_encoder[-1] = 100 - upper_limit_encoder
    proportions_decoder[-1] = 100 - upper_limit_decoder
        
    # Add and combine weights of all input encoders/decoders according to proportion
    for k in weights_avg_e.keys():
        for i in range(0, len(encoder_weights)):
            weights_avg_e[k] += proportions_encoder[i] * encoder_weights[i][k] / 100
    
    for k in weights_avg_d.keys():
        for i in range(0, len(decoder_weights)):
            
            weights_avg_d[k] += proportions_decoder[i] * decoder_weights[i][k] / 100

    # Aggregate all encoders and decoders to update globals
    encoder_og = global_encoder_model.state_dict()
    encoder_og.update(weights_avg_e)
    global_encoder_model.load_state_dict(encoder_og)

    decoder_og = global_decoder_model.state_dict()
    decoder_og.update(weights_avg_d)
    global_decoder_model.load_state_dict(decoder_og)

    # Train head for target lang
    global_encoder_model, global_decoder_model, test_pairs = personalize('kir', 10, save=False, encoder_weights=global_encoder_model.state_dict(), decoder_weights=global_decoder_model.state_dict(), sample=None, lang_info=(input_lang, output_lang, target_dataloader, target_pairs))

    #REPLACE TEST PAIRS
    test_bleu = evaluateBleu(global_encoder_model, global_decoder_model, input_lang, output_lang, test_pairs, n=min(100,len(test_pairs)))

    return test_bleu

## Experiment Management
Functions to setup and run experiments.

In [17]:
def setup_experiment(languages, batch_size=32, hidden_size=256):
    data_dict = {}
    encoders = {}
    decoders = {}
    input_output_lang = {}
    
    for i, lang in enumerate(languages):
        print(f"Setting up {lang}...")
        input_lang, output_lang, train_dataloader, pairs = get_dataloader(batch_size, language=lang)
        train_dataloader, test_dataloader, split = split_dataloader(train_dataloader)
        train_idx = get_pair_index(train_dataloader)
        test_idx = get_pair_index(test_dataloader)
        
        train_pairs = [pairs[i] for i in train_idx]
        test_pairs = [pairs[i] for i in test_idx]
        
        encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
        decoder = AttnDecoderRNN(hidden_size, output_lang.n_words).to(device)
        
        encoders[i] = encoder
        decoders[i] = decoder
        data_dict[i] = train_dataloader
        input_output_lang[i] = (input_lang, output_lang, train_pairs, test_pairs)
        
    return encoders, decoders, data_dict, input_output_lang

def run_experiment(languages, rounds=10):
    batch_size = 32
    hidden_size = 256
    
    encoders, decoders, data_dict, input_output_lang = setup_experiment(languages, batch_size, hidden_size)
    
    # Run Federated Learning Training
    print("Starting Federated Learning...")
    training(encoders, decoders, input_output_lang, rounds, lr=0.001, ds=data_dict, C=1.0, K=len(languages), E=1, filename="experiment_log.csv", batch_size=batch_size, hidden_size=hidden_size)

## Main Execution

In [None]:
# Define languages to experiment with
langs = ['bem', 'kin', 'swa'] # Example languages

# Run the experiment
run_experiment(langs, rounds=5)

Setting up bem...
Reading lines...
Read 82370 sentence pairs
Trimmed to 71769 sentence pairs
Counting words...
Counted words:
bem 55089
eng 11061
Setting up kin...
Reading lines...
Read 55667 sentence pairs
Trimmed to 46273 sentence pairs
Counting words...
Counted words:
kin 85445
eng 17897
Setting up swa...
Reading lines...


KeyboardInterrupt: 

Exception ignored in: 'zmq.backend.cython._zmq.Frame.__dealloc__'
Traceback (most recent call last):
  File "zmq/backend/cython/_zmq.py", line 179, in zmq.backend.cython._zmq._check_rc
KeyboardInterrupt: 


Read 272543 sentence pairs
Trimmed to 242315 sentence pairs
Counting words...
Counted words:
swa 55715
eng 31494
Starting Federated Learning...
Reading lines...
Read 1375 sentence pairs
Trimmed to 1373 sentence pairs
Counting words...
Counted words:
cat 1815
eng 1464


  0%|          | 0/3 [00:00<?, ?it/s]