This notebook runs the federated class-incremental learning setting.

In [13]:
# Set the train parameters.
import random

import numpy as np
import torch

from configs import configs

# Batch sizes and per-round dataset sizes for the train/test splits.
configs['test_batchsize']=128
configs['train_batchsize']=128
configs['test_dataset_size']=1000
configs['train_dataset_size']=5000
# Gaussian prior of the Bayesian layers (read when the BNN model is built).
configs['prior_mu']=0.0
configs['prior_sigma']=0.05
# Number of initial rounds trained as a plain (non-Bayesian) SNN before the BNN phase.
configs['snn_initialize_rounds']=3
# Remaining training hyper-parameters; they are consumed inside the client
# implementations, which are not visible in this notebook.
configs['head_finetune_epoch']=3
configs['train_epoch']=8
configs['kl_weight']=1
configs['head_ft_learn_rate']=0.001
configs['learn_rate']=0.001
# Presumably Monte Carlo sample counts used by the Bayesian client — confirm in fcl_client.
configs['monte_carlo_times']=100
configs['grad_mc_times']=10
configs['gpu']=False

number_of_clients=10
total_rounds=100

# Seed every RNG the experiment may draw from (model initialisation, Dirichlet
# partitioning, dataset sampling) so that Restart & Run All reproduces the results.
SEED=0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load the dataset.
from fcl_data_simulator.single_dataset import CIFAR100

# resize_to=32 presumably yields 32x32 images; the 400-feature Linear layer of the
# models below (16 channels * 5 * 5) assumes that input size.
cifar100=CIFAR100(resize_to=32)
trainset,testset=cifar100["train"],cifar100["test"]

In [None]:
from fcl_data_simulator.dataset_utils import get_index_by_class
from fcl_data_simulator.dataset_utils import create_sampled_dataset

# Sample indices grouped by class label, for each split.
train_idx=get_index_by_class(trainset)
test_idx=get_index_by_class(testset)

# Four disjoint tasks of 25 consecutive class labels each: 0-24, 25-49, 50-74, 75-99.
class_idx=[list(range(first,first+25)) for first in range(0,100,25)]

train_tasks=[]
test_tasks=[]
for classes_of_task in class_idx:
    # Concatenate the sample indices of every class in this task, class by class.
    train_members=[sample for label in classes_of_task for sample in train_idx[label]]
    test_members=[sample for label in classes_of_task for sample in test_idx[label]]
    # NOTE(review): the meaning of the boolean flag is defined in create_sampled_dataset
    # and is not visible here.
    train_tasks.append(create_sampled_dataset(trainset,train_members,True))
    test_tasks.append(create_sampled_dataset(testset,test_members,True))

In [None]:
from fcl_data_simulator.continual_policy import TaskSeparateContiunalPolicy
from fcl_data_simulator.partition_policy import PartitioningPolicies
from fcl_data_simulator.data_manager import DataManager

# Continual policies: which task's data is served in each round.
# NOTE(review): the positional 25 equals both the classes per task and
# total_rounds / len(tasks); confirm its meaning against TaskSeparateContiunalPolicy.
train_cp=TaskSeparateContiunalPolicy(train_tasks,25,configs['train_dataset_size'])
test_cp=TaskSeparateContiunalPolicy(test_tasks,25,configs['test_dataset_size'])

# Both splits use Dirichlet non-IID partitioning. The test split has a single
# "client", so the evaluation loop can always read slice 0.
partition_policy=PartitioningPolicies.dirichlet_nonIID_partitioning
partition_policy_args={"number_of_clients":number_of_clients}
test_partition_policy_args={"number_of_clients":1}

train_data_manager=DataManager(
    configs["train_batchsize"],train_cp,partition_policy,partition_policy_args)
test_data_manager=DataManager(
    configs["test_batchsize"],test_cp,partition_policy,test_partition_policy_args)

print("The fcl data simulator is set.")

During the first few rounds (`configs['snn_initialize_rounds']`), a standard, non-Bayesian neural network (SNN) is trained.

In [None]:
import torch

# Shared feature extractor for the SNN phase, built from plain deterministic layers.
# For a 3x32x32 input: conv5 -> 6x28x28, pool -> 6x14x14, conv5 -> 16x10x10,
# pool -> 16x5x5, flatten -> 400 features, then projected to 200 features.
shared_model=torch.nn.Sequential(
    torch.nn.Conv2d(in_channels=3,out_channels=6,kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2,stride=2),
    torch.nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(kernel_size=2,stride=2),
    torch.nn.Flatten(),
    torch.nn.Linear(in_features=16*5*5,out_features=200),
)
# The head library starts empty; the training loop replaces it every round with
# fcl_server.aggregate_heads(...).
head_library={}

In [None]:
from baselines.fcl_fedavg_client import fcl_fedavg_client

# One SNN-phase client per simulated participant.
clients=[]
for _ in range(number_of_clients):
    clients.append(fcl_fedavg_client())
# An extra client that is only used to evaluate the aggregated server model.
pseudo_client=fcl_fedavg_client()

In [None]:
import evaluation
from fcl_server import fcl_server
from fcl_data_simulator.dataset_utils import sample_slice
# Now the train process.

# Per round: the list of training accuracies returned by each client.
client_accuracies=[]
# Per round: test accuracy of the aggregated model on the current test slice.
server_accuracy=[]
# Per round: accuracies of the aggregated model on each task from get_past_tasks().
task_accuracies=[]

# SNN warm-up phase. The loop variable is `round_idx` rather than `round`, so the
# built-in round() is not shadowed for the rest of the notebook.
for round_idx in range(configs['snn_initialize_rounds']):
    client_weight=[]
    client_model=[]
    client_head=[]
    client_acc=[]
    for client_idx,client in enumerate(clients):
        # Every client starts the round from the latest aggregated model and heads.
        client.update_model(shared_model,head_library)
        client.head_need_ft=False # The head_library is shared so don't need ft.

        train_slice,slice_classes=train_data_manager.get_slice(client_idx)

        # Each client's contribution to aggregation is weighted by its local dataset size.
        client_weight.append(len(train_slice.dataset))

        # head_library is still empty in round 0, so head fine-tuning is skipped then.
        acc=client.train_model(train_slice,slice_classes,head_ft=(round_idx!=0))

        client_acc.append(acc)

        model,head=client.get_model()

        client_model.append(model)
        client_head.append(head)

    print("Round {}".format(round_idx))
    print("\tClient accuracies:",client_acc)
    client_accuracies.append(client_acc)
    shared_model=fcl_server.aggregate_shared_model(client_model,client_weight)
    head_library=fcl_server.aggregate_heads(client_head,client_weight)

    # Now do the test on the single test partition (slice 0).
    pseudo_client.update_model(shared_model,head_library)
    test_slice,test_slice_classes=test_data_manager.get_slice(0)
    test_acc=pseudo_client.test_model(test_slice,test_slice_classes)
    print("\tTest accuracy:",test_acc)
    server_accuracy.append(test_acc)

    # Evaluate on past tasks.
    past_tasks=test_data_manager.get_past_tasks()
    past_task_acc=[]
    for past_task in past_tasks:
        ptask_slice,ptask_classes=sample_slice(past_task,
                                    configs['test_dataset_size'],
                                    configs['test_batchsize'])
        ptask_acc=pseudo_client.test_model(ptask_slice,ptask_classes)
        past_task_acc.append(ptask_acc)

    task_accuracies.append(past_task_acc)

    # Step the data manager until the last round.
    if round_idx!=total_rounds-1:
        train_data_manager.next_round()
        test_data_manager.next_round()

We now transform the SNN model into a Bayesian neural network (BNN) and train it the FedBNN way.

In [5]:
import torchbnn

# Keep handles to the trained SNN; its weights are copied into the BNN in the next cell.
snn_shared_model=shared_model
snn_head_library=head_library

# Every Bayesian layer uses the same Gaussian prior taken from `configs`.
bayes_prior={"prior_mu":configs["prior_mu"],"prior_sigma":configs["prior_sigma"]}

# Same architecture as the SNN feature extractor, with Bayesian conv/linear layers.
shared_model=torch.nn.Sequential(
    torchbnn.BayesConv2d(**bayes_prior,in_channels=3,out_channels=6,kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2,2),
    torchbnn.BayesConv2d(**bayes_prior,in_channels=6,out_channels=16,kernel_size=5),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2,2),
    torch.nn.Flatten(),
    torchbnn.BayesLinear(**bayes_prior,in_features=400,out_features=200),
)
# The head library starts empty here and is populated from the SNN heads in the next cell.
head_library={}

In [None]:
import snn_to_bnn

# Transfer the trained SNN parameters into the freshly built BNN counterparts.
# Each call takes the source first and the destination second; the destination
# objects created in the previous cell are not reassigned here.
snn_to_bnn.head_library_transform(
    snn_head_library,
    head_library,
)
snn_to_bnn.model_transform(
    snn_shared_model,
    shared_model,
)

In [6]:
from fcl_client import fcl_client

# One BNN-phase client per simulated participant.
clients=[]
for _ in range(number_of_clients):
    clients.append(fcl_client())
# An extra client that is only used to evaluate the aggregated server model.
pseudo_client=fcl_client()

In [None]:
# Now the train process.
# BNN phase: continues the round counter where the SNN warm-up stopped. The loop
# variable is `round_idx` rather than `round`, so the built-in round() is not shadowed.
for round_idx in range(configs['snn_initialize_rounds'],total_rounds):
    client_weight=[]
    client_model=[]
    client_head=[]
    client_acc=[]
    for client_idx,client in enumerate(clients):
        # Every client starts the round from the latest aggregated model and heads.
        client.update_model(shared_model,head_library)
        client.head_need_ft=False # The head_library is shared so don't need ft.

        train_slice,slice_classes=train_data_manager.get_slice(client_idx)

        # Each client's contribution to aggregation is weighted by its local dataset size.
        client_weight.append(len(train_slice.dataset))

        # This loop starts at snn_initialize_rounds, so head_ft is True on every round
        # unless snn_initialize_rounds is 0.
        acc=client.train_model(train_slice,slice_classes,head_ft=(round_idx!=0))

        client_acc.append(acc)

        model,head=client.get_model()

        client_model.append(model)
        client_head.append(head)

    print("Round {}".format(round_idx))
    print("\tClient accuracies:",client_acc)
    client_accuracies.append(client_acc)

    shared_model=fcl_server.aggregate_shared_model(client_model,client_weight)
    head_library=fcl_server.aggregate_heads(client_head,client_weight)

    # Now do the test on the single test partition (slice 0).
    pseudo_client.update_model(shared_model,head_library)
    test_slice,test_slice_classes=test_data_manager.get_slice(0)
    test_acc=pseudo_client.test_model(test_slice,test_slice_classes)
    print("\tTest accuracy:",test_acc)
    server_accuracy.append(test_acc)

    # Evaluate on past tasks.
    past_tasks=test_data_manager.get_past_tasks()
    past_task_acc=[]
    for past_task in past_tasks:
        ptask_slice,ptask_classes=sample_slice(past_task,
                                    configs['test_dataset_size'],
                                    configs['test_batchsize'])
        ptask_acc=pseudo_client.test_model(ptask_slice,ptask_classes)
        past_task_acc.append(ptask_acc)

    task_accuracies.append(past_task_acc)

    # Step the data manager until the last round.
    if round_idx!=total_rounds-1:
        train_data_manager.next_round()
        test_data_manager.next_round()

In [None]:
# Now dump the accuracy for future use.
import json

# default=float converts numeric scalars that json cannot serialize natively
# (e.g. numpy.float32 or 0-d torch tensors, should train_model/test_model return
# them) instead of raising TypeError after a full training run. Values that are
# already Python numbers are serialized exactly as before.
jsonstr=json.dumps({"client_acc":client_accuracies,"server_acc":server_accuracy,"task":task_accuracies},
                   default=float)
print(jsonstr)