## Part 2: Federated Learning

You have already trained the model locally on your device with your (limited) data. We now want to extend this to federated learning (FL).

Some helper functions are already implemented in *clientlib.py*:
- **wait_for_next_round**(server, last_trained_round=None, join_late_by_max=10):
Waits until a new round has started. A round counts as new if it is younger than *join_late_by_max* seconds. The parameter *last_trained_round* can be set to the last round in which the client participated, so that the function waits for a round != last_trained_round.

- **get_model_and_notify_client_started**(server, client_id):
Registers the client with the server and downloads the current model and metadata. *metadata['round']* contains the current round number.

- **upload_updated_model**(server, client_id, model, model_metadata):
Uploads the updated model to the server. *model_metadata* must be the same as returned by *get_model_and_notify_client_started*.
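
The join rule described for *wait_for_next_round* can be illustrated with a small sketch. The function `is_joinable` and its arguments `round_number` and `round_age_s` are hypothetical and only mirror the description above; the actual logic inside *clientlib.py* may differ.

```python
def is_joinable(round_number, round_age_s, last_trained_round=None, join_late_by_max=10):
    """Hypothetical illustration of the join rule; not the clientlib implementation."""
    # A round the client already trained in must be skipped.
    if round_number == last_trained_round:
        return False
    # A round counts as new only while it is younger than join_late_by_max seconds.
    return round_age_s < join_late_by_max


print(is_joinable(4, round_age_s=3.0, last_trained_round=3))   # True
print(is_joinable(4, round_age_s=12.0, last_trained_round=3))  # False
print(is_joinable(3, round_age_s=3.0, last_trained_round=3))   # False
```

Under this reading, a client that finished round 3 waits until round 4 appears and joins only if that round started less than 10 seconds ago.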

FL specifications:
- Train for **one epoch** on your local dataset in each round. Testing is not required (but may be interesting).
- Use the same hyperparameters as in part 1.
- A round ends after 60 s at the latest. If you upload your model too late, the server rejects it. **Do not cheat by changing model_metadata before uploading.**
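
Because late uploads are rejected, it can help to track how much of the round is left. The following is a minimal sketch, assuming the 60 s round length from the specification. The helper `seconds_left` is hypothetical, and so is the choice to time from the moment the model was downloaded. Since the client may have joined up to *join_late_by_max* seconds late, the sketch subtracts that delay as a pessimistic margin.

```python
import time

ROUND_LENGTH_S = 60  # from the specification above


def seconds_left(joined_at, now=None, join_delay_s=10, round_length_s=ROUND_LENGTH_S):
    """Pessimistic estimate of the time left in the round (hypothetical helper).

    joined_at    -- time.monotonic() value taken right after the model download
    join_delay_s -- assumed worst-case lateness when joining (join_late_by_max)
    """
    if now is None:
        now = time.monotonic()
    elapsed = now - joined_at
    return round_length_s - join_delay_s - elapsed


joined_at = time.monotonic()
# ... train for one epoch here ...
if seconds_left(joined_at) <= 0:
    print("round is probably over, an upload would likely be rejected")
```

`time.monotonic()` is used instead of `time.time()` so that system clock adjustments do not distort the measurement.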

**Your Task:**
- Implement the federated learning training on the client. To do so, repeat the following steps in a loop:
- 1) Wait for the next round to start using the helper function `clientlib.wait_for_next_round()`.
- 2) Download the newly averaged model from the server using the helper function `clientlib.get_model_and_notify_client_started()`.
- 3) Train the model on your data, as in on-device training.
- 4) Upload the model using the helper function `clientlib.upload_updated_model()`.
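
Step 2 downloads a "newly averaged" model. How the server averages is not specified here. A common choice is plain federated averaging, where each parameter is the mean of the clients' values. The sketch below assumes unweighted averaging and represents each model as a dict mapping parameter names to flat lists of floats. Both `average_models` and this representation are illustrative, not the lab server's code.

```python
def average_models(client_models):
    """Unweighted element-wise mean of several models (illustrative sketch).

    Each model is a dict: parameter name -> flat list of floats.
    All models are assumed to share the same names and shapes.
    """
    if not client_models:
        raise ValueError("need at least one client model")
    averaged = {}
    for name in client_models[0]:
        # Collect this parameter from every client and average position by position.
        columns = zip(*(model[name] for model in client_models))
        averaged[name] = [sum(values) / len(client_models) for values in columns]
    return averaged


client_a = {"weight": [1.0, 2.0], "bias": [0.0]}
client_b = {"weight": [3.0, 4.0], "bias": [1.0]}
print(average_models([client_a, client_b]))
# {'weight': [2.0, 3.0], 'bias': [0.5]}
```

A weighted variant would scale each client's contribution by its number of training samples; whether the server does this is not stated in this section.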


In [1]:
TOTAL_CLIENTS = 10  # number of participants in the lab
CLIENT_ID = 3  # between 0 and TOTAL_CLIENTS-1

In [2]:
import torch
import device_data

training_dataset = device_data.get_client_training_data(CLIENT_ID, TOTAL_CLIENTS)
train_loader = torch.utils.data.DataLoader(training_dataset, batch_size=16)
test_dataset = device_data.get_test_data()
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
import clientlib
SERVER = 'federated-learning.in8.itec.kit.edu:80'

# this reaches 76% within 10 rounds with 4 clients

torch.set_num_threads(1)

import time
from tqdm import tqdm
import sys

# fall back to the CPU when no GPU is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def correct_predictions(outputs, targets):
    # count the samples whose arg-max class matches the label
    predicted = outputs.argmax(dim=1)
    return int((predicted == targets.to(predicted.device)).sum())

def train(model, optimizer, criterion, trainloader, device='cpu'):
    #-to-be-done-by-student---------
    model.train()
    model.to(device)
    #-------------------------------
    for _, (inputs, targets) in enumerate(tqdm(trainloader, ncols=80,
                                               file=sys.stdout, desc="Training", leave=False)):
        
        #-to-be-done-by-student----
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)  # use the criterion passed in, not a global
        loss.backward()
        optimizer.step()
        #--------------------------

def test(model, testloader, device='cpu'):
    num_correct = 0
    num_samples = 0
    
    #-to-be-done-by-student---------
    model.eval()
    #-------------------------------
    for _, (inputs, targets) in enumerate(tqdm(testloader, ncols=80,
                                               file=sys.stdout, desc="Testing", leave=False)):
        #-to-be-done-by-student----
        inputs = inputs.to(device)
        targets = targets.to(device)
        with torch.no_grad():  # no gradients are needed for evaluation
            outputs = model(inputs)
        num_correct += correct_predictions(outputs, targets)
        num_samples += len(inputs)
        #--------------------------
    return num_correct / num_samples

last_trained_round = None

#criterion = torch.nn.CrossEntropyLoss()
#optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

while True:
    print("Starting round")
    
    #-to-be-done-by-student-----
    clientlib.wait_for_next_round(SERVER, last_trained_round)
    model, model_metadata = clientlib.get_model_and_notify_client_started(SERVER, CLIENT_ID)
    criterion = loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    #---------------------------

    print(f'downloaded model -> train (round {model_metadata["round"]})')

    start = time.time()
    #-to-be-done-by-student-----
    train(model, optimizer, criterion, train_loader, device=device)
    #---------------------------
    end = time.time()
    print(f'training took {end-start:.1f}s')

    print('upload updated model')
    #-to-be-done-by-student-----
    clientlib.upload_updated_model(SERVER, CLIENT_ID, model, model_metadata)
    #---------------------------

    last_trained_round = model_metadata['round']
    
    # measure the accuracy every 10 rounds
    if model_metadata['round'] % 10 == 0:
        accuracy = test(model, test_loader, device=device)
        print(f'Accuracy: ({accuracy:.2f})')
        #break

Starting round
downloaded model -> train (round 1)
training took 27.4s                                                             
upload updated model
Starting round
downloaded model -> train (round 2)
training took 9.1s                                                              
upload updated model
Starting round
downloaded model -> train (round 3)
training took 8.9s                                                              
upload updated model
Starting round
downloaded model -> train (round 4)
training took 8.9s                                                              
upload updated model
Starting round
downloaded model -> train (round 5)
training took 8.9s                                                              
upload updated model
Starting round
downloaded model -> train (round 6)
training took 9.0s                                                              
upload updated model
Starting round
downloaded model -> train (round 7)
training took 8.9s             


KeyboardInterrupt

