In [32]:
import torch
import copy
import shutil
import logging
import subprocess
import subprocess
import concurrent.futures
import os
from gpt import GPTLanguageModel
logger = logging.getLogger(__name__)

logging.basicConfig(filename="../federate.log", level = logging.INFO)

num_rounds = 50

In [33]:
# Load initial global model
global_model = torch.load('../models/global_model.pth')

In [34]:
# Alternatively, initialize with an average of edge models
edge_model_movies = torch.load('../models/edge_model_movies.pth')
edge_model_scientific = torch.load('../models/edge_model_scientific_papers.pth')

In [35]:
def send_global_model_to_clients(model):
    # copy global_model from global to clients
    source = "../global/global_model.pth"
    destination = ['../clients/movie_client/', '../clients/research_client/']

    for dst in destination:
        logging.info(f'Copying global_model from global to {dst}: INPROGRESS')
        shutil.copy(source, dst)
        logging.info(f'Copying global_model from global to {dst}: DONE')

In [36]:
def run_client(c, gpu_id):
    # Set the CUDA_VISIBLE_DEVICES environment variable to specify the GPU
    env = {"CUDA_VISIBLE_DEVICES": str(gpu_id)}
    command = ["python3", "gpt.py"]
    subprocess.run(command, cwd=c, env={**env, **os.environ})

def run_client_local_models():
    clients = ['movie_client', 'research_client']
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Map the run_client function to the clients and assign a different GPU to each
        futures = [executor.submit(run_client, c, i) for i, c in enumerate(clients)]
        
        # Wait for all futures to complete
        concurrent.futures.wait(futures)
    client_training = True
    

In [37]:
def federated_averaging(global_model, client_models):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Get the state dict of the global model
    model = GPTLanguageModel()
    model.load_state_dict(torch.load(global_model))
    model.to(device)
    global_state_dict = model.state_dict()

    # Initialize an empty dictionary to hold the averaged weights
    averaged_state_dict = {key: torch.zeros_like(value) for key, value in global_state_dict.items()}

    # Sum the weights from each client model
    for client_model in client_models:

        client_model = GPTLanguageModel()
        client_model.load_state_dict(torch.load(global_model))
        client_model.to(device)
        client_state_dict = client_model.state_dict()

        for key in averaged_state_dict.keys():
            averaged_state_dict[key] += client_state_dict[key]

    # Average the weights
    num_clients = len(client_models)
    for key in averaged_state_dict.keys():
        averaged_state_dict[key] /= num_clients

    # Load the averaged weights into the global model
    new_global_model = GPTLanguageModel()
    new_global_model.load_state_dict(averaged_state_dict)

    return global_model

In [38]:
def clear_files(files):
    for file in files:
        try:
            os.remove(file)
            logging.info("Removed file from models")
        except Exception as e:
            logging.info(f"EXCEPTION: {e}")

In [41]:
global_model_path = '../models/global_model.pth'
client_model_paths = ['../models/edge_model_movies.pth', '../models/edge_model_scientific_papers.pth']

for round in range(num_rounds):
    print(1)
    client_training = False
    send_global_model_to_clients(global_model_path)
    print(2)
    run_client_local_models()
    print(3)
    if client_training:
        global_model = federated_averaging(global_model_path, client_model_paths)
        model_path = '../models/global_model.pth'
        torch.save(global_model.state_dict(), model_path)
        logging.info(f"Model saved as {model_path}")
    logging.info(f"FEDERATE: Round {round} complete.")
    clear_files(client_model_paths)
    logging.info("Removed client models after fedavg.")


1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
1
2
3
