This notebook runs the federated class-incremental setting.

In [None]:
# Set the train parameters.
import random

import numpy as np
import torch

from configs import configs

# Seed the Python, NumPy and PyTorch global RNGs so the stochastic parts of the
# run (e.g. the Dirichlet client partition, model initialisation, slice
# sampling) are reproducible under Restart & Run All.
# NOTE(review): assumes the project code draws from these global generators.
SEED=0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Batch sizes, dataset sizes, epochs and learning rates read by the project code.
configs['test_batchsize']=128
configs['test_dataset_size']=1000
configs['train_dataset_size']=5000
configs['head_finetune_epoch']=3
configs['train_epoch']=8
configs['head_ft_learn_rate']=0.001
configs['learn_rate']=0.001
configs['train_batchsize']=128

# Federation size and length of the whole experiment (in communication rounds).
number_of_clients=10
total_rounds=90

In [None]:
# Load the three source datasets; each one later becomes one task in the sequence.
from fcl_data_simulator.single_dataset import MNIST,EMNIST,USPS

# Constructed left to right, i.e. MNIST, then EMNIST, then USPS.
mnist,emnist,usps=MNIST(),EMNIST(),USPS()

In [None]:
from fcl_data_simulator.continual_policy import GradualContiunalPolicy
from fcl_data_simulator.data_manager import DataManager
from fcl_data_simulator.partition_policy import PartitioningPolicies

# Task order: USPS, then MNIST, then EMNIST. Train and test use the same order so
# a round's test data corresponds to the same task mixture as its train data.
train_tasks=[usps["train"],mnist["train"],emnist["train"]]
test_tasks=[usps["test"],mnist["test"],emnist["test"]]

# Per-task schedule shared by the train and test policies (one entry per task).
# Presumably the rounds at which each task starts / finishes being phased in by
# the linear gradual policy — confirm against GradualContiunalPolicy.
TASK_START_ROUNDS=(0,20,50)
TASK_END_ROUNDS=(40,70,90)
assert len(TASK_START_ROUNDS)==len(TASK_END_ROUNDS)==len(train_tasks)==len(test_tasks), \
    "task schedule must have one start/end round per task"


def make_continual_policy(tasks,dataset_size):
    """Build a linear GradualContiunalPolicy for `tasks` on the shared schedule.

    Fresh lists are passed on every call so the train and test policies never
    share a mutable schedule object.
    """
    return GradualContiunalPolicy.create_by_task_durations_linear(
        tasks,list(TASK_START_ROUNDS),list(TASK_END_ROUNDS),dataset_size)


train_cp=make_continual_policy(train_tasks,configs['train_dataset_size'])
test_cp=make_continual_policy(test_tasks,configs['test_dataset_size'])

# Training data is split non-IID across clients (Dirichlet); the test stream is
# kept as a single IID partition that the server-side pseudo client evaluates.
train_partition_policy=PartitioningPolicies.dirichlet_nonIID_partitioning
test_partition_policy=PartitioningPolicies.IID_partitioning
partition_policy_args={"number_of_clients":number_of_clients}
test_partition_policy_args={"number_of_clients":1}

# Create the DataManager.
train_data_manager=DataManager(configs["train_batchsize"],train_cp,
                    train_partition_policy,partition_policy_args)
test_data_manager=DataManager(configs["test_batchsize"],test_cp,
                    test_partition_policy,test_partition_policy_args)

print("The fcl data simulator is set.")

In [None]:
import torch

# Shared body of the network, in forward order: flatten each sample to
# 28*28 features, then a 784 -> 400 -> 200 -> 200 MLP with ReLU between the
# linear layers and no activation after the last one.
_shared_layers=[
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=28*28,out_features=400),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=400,out_features=200),
    torch.nn.ReLU(),
    torch.nn.Linear(in_features=200,out_features=200),
]
shared_model=torch.nn.Sequential(*_shared_layers)
del _shared_layers

# The head library, passed to the clients alongside the shared model; starts empty.
head_library=dict()

In [None]:
from baselines.fcl_fedavg_client import fcl_fedavg_client

# One FedAvg client per participant; a client's list position is the id used
# to fetch its data slice from the train DataManager.
clients=[fcl_fedavg_client() for client_id in range(number_of_clients)]

# Extra client that is only loaded with the aggregated model and used for testing.
pseudo_client=fcl_fedavg_client()

In [None]:
import evaluation
from fcl_server import fcl_server
from fcl_data_simulator.dataset_utils import sample_slice
# Now the train process.

# Per-round list of every client's training accuracy (outer index = round,
# inner index = client).
client_accuracies=[]
# Per-round test accuracy of the aggregated model on the current test slice.
server_accuracy=[]
# Per-round list of the aggregated model's test accuracy on each task returned
# by test_data_manager.get_past_tasks().
task_accuracies=[]


def evaluate_past_tasks(model_client,data_manager):
    """Test `model_client` on a freshly sampled slice of every past task.

    Args:
        model_client: client already loaded with the model to evaluate.
        data_manager: DataManager whose get_past_tasks() lists the tasks to test.

    Returns:
        A list with one accuracy per past task, in get_past_tasks() order.
    """
    accuracies=[]
    for past_task in data_manager.get_past_tasks():
        ptask_slice,ptask_classes=sample_slice(past_task,
                                    configs['test_dataset_size'],
                                    configs['test_batchsize'])
        accuracies.append(model_client.test_model(ptask_slice,ptask_classes))
    return accuracies


# `round_idx`, not `round`: the loop variable outlives this cell in the notebook
# namespace and would otherwise shadow the built-in round().
for round_idx in range(total_rounds):
    client_weight=[]
    client_model=[]
    client_head=[]
    client_acc=[]
    for client_idx,client in enumerate(clients):
        client.update_model(shared_model,head_library)
        client.head_need_ft=False # The head_library is shared so don't need ft.

        train_slice,slice_classes=train_data_manager.get_slice(client_idx)

        # Aggregation weight = number of samples this client trained on.
        client_weight.append(len(train_slice.dataset))

        # Head fine-tuning is skipped only in the very first round.
        acc=client.train_model(train_slice,slice_classes,head_ft=(round_idx!=0))
        client_acc.append(acc)

        model,head=client.get_model()
        client_model.append(model)
        client_head.append(head)

    print(f"Round {round_idx}")
    print("\tClient accuracies:",client_acc)
    client_accuracies.append(client_acc)
    shared_model=fcl_server.aggregate_shared_model(client_model,client_weight)
    head_library=fcl_server.aggregate_heads(client_head,client_weight)

    # Now do the test on the current (single-partition) test slice.
    pseudo_client.update_model(shared_model,head_library)
    test_slice,test_slice_classes=test_data_manager.get_slice(0)
    test_acc=pseudo_client.test_model(test_slice,test_slice_classes)
    print("\tTest accuracy:",test_acc)
    server_accuracy.append(test_acc)

    # Evaluate on past tasks.
    task_accuracies.append(evaluate_past_tasks(pseudo_client,test_data_manager))

    # Step the data managers after every round except the last one.
    if round_idx!=total_rounds-1:
        train_data_manager.next_round()
        test_data_manager.next_round()

In [None]:
# Now dump the accuracy for future use.
import json

# Gather every accuracy history into one JSON object and print it, so the
# results can be copied from the cell output.
accuracy_record={
    "client_acc":client_accuracies,
    "server_acc":server_accuracy,
    "task":task_accuracies,
}
jsonstr=json.dumps(accuracy_record)
print(jsonstr)