[feature] FLIT for federated graph classification/regression (#87)
Implement FLIT and FLIT+ (with FedVAT and FedFocal variants) for federated molecular graph classification/regression: new trainers, an MPNN model option, a scaffold+LDA splitter, additional MoleculeNet datasets, and an example config.
wanghh7 committed May 23, 2022
1 parent 3023be7 commit e625a0d
Showing 19 changed files with 716 additions and 10 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -135,3 +135,5 @@ dmypy.json
 .pyre/

 .idea/
+
+**/.DS_Store
3 changes: 2 additions & 1 deletion federatedscope/core/auxiliaries/data_builder.py
@@ -523,7 +523,8 @@ def get_data(config):
         from federatedscope.gfl.dataloader import load_linklevel_dataset
         data, modified_config = load_linklevel_dataset(config)
     elif config.data.type.lower() in [
-            'hiv', 'proteins', 'imdb-binary'
+            'hiv', 'proteins', 'imdb-binary', 'bbbp', 'tox21', 'bace', 'sider', 'clintox',
+            'esol', 'freesolv', 'lipo'
     ] or config.data.type.startswith('graph_multi_domain'):
         from federatedscope.gfl.dataloader import load_graphlevel_dataset
         data, modified_config = load_graphlevel_dataset(config)
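(The added names are MoleculeNet molecular property benchmarks: bbbp, tox21, bace, sider, and clintox are graph classification tasks, while esol, freesolv, and lipo are regression tasks, matching the classification/regression scope in the commit title.)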
2 changes: 1 addition & 1 deletion federatedscope/core/auxiliaries/model_builder.py
@@ -76,7 +76,7 @@ def get_model(model_config, local_data, backend='torch'):
     elif model_config.type.lower().endswith('transformers'):
         from federatedscope.nlp.model import get_transformer
         model = get_transformer(model_config, local_data)
-    elif model_config.type.lower() in ['gcn', 'sage', 'gpr', 'gat', 'gin']:
+    elif model_config.type.lower() in ['gcn', 'sage', 'gpr', 'gat', 'gin', 'mpnn']:
         from federatedscope.gfl.model import get_gnn
         model = get_gnn(model_config, local_data)
     elif model_config.type.lower() in ['vmfnet', 'hmfnet']:
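The new `mpnn` type presumably maps to a message passing neural network in the style of Gilmer et al. (2017), a standard choice for molecular property prediction that can consume the bond features produced by the featurizer added below.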
3 changes: 3 additions & 0 deletions federatedscope/core/auxiliaries/splitter_builder.py
@@ -36,6 +36,9 @@ def get_splitter(config):
     elif config.data.splitter == 'scaffold':
         from federatedscope.core.splitters.graph import ScaffoldSplitter
         splitter = ScaffoldSplitter(client_num, **args)
+    elif config.data.splitter == 'scaffold_lda':
+        from federatedscope.core.splitters.graph import ScaffoldLdaSplitter
+        splitter = ScaffoldLdaSplitter(client_num, **args)
     elif config.data.splitter == 'rand_chunk':
         from federatedscope.core.splitters.graph import RandChunkSplitter
         splitter = RandChunkSplitter(client_num, **args)
8 changes: 8 additions & 0 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -13,6 +13,10 @@
     "linkminibatch_trainer": "LinkMiniBatchTrainer",
     "nodefullbatch_trainer": "NodeFullBatchTrainer",
     "nodeminibatch_trainer": "NodeMiniBatchTrainer",
+    "flitplustrainer": "FLITPlusTrainer",
+    "flittrainer": "FLITTrainer",
+    "fedvattrainer": "FedVATTrainer",
+    "fedfocaltrainer": "FedFocalTrainer",
     "mftrainer": "MFTrainer",
 }

@@ -59,6 +63,10 @@ def get_trainer(model=None,
             'nodefullbatch_trainer', 'nodeminibatch_trainer'
     ]:
         dict_path = "federatedscope.gfl.trainer.nodetrainer"
+    elif config.trainer.type.lower() in [
+            'flitplustrainer', 'flittrainer', 'fedvattrainer', 'fedfocaltrainer'
+    ]:
+        dict_path = "federatedscope.gfl.flitplus.trainer"
     elif config.trainer.type.lower() in ['mftrainer']:
         dict_path = "federatedscope.mf.trainer.trainer"
     else:
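As a reading aid, here is a minimal sketch of how the registry and `dict_path` above combine at runtime; `resolve_flit_trainer` is a hypothetical helper, and the real `get_trainer` handles many more trainer families and constructor arguments.

```python
import importlib

# Hypothetical helper mirroring the registry + dict_path mechanism above.
FLIT_TRAINERS = {
    "flitplustrainer": "FLITPlusTrainer",
    "flittrainer": "FLITTrainer",
    "fedvattrainer": "FedVATTrainer",
    "fedfocaltrainer": "FedFocalTrainer",
}

def resolve_flit_trainer(trainer_type):
    # All four FLIT-family trainers live in one module; the class is
    # looked up by name and imported lazily.
    module = importlib.import_module("federatedscope.gfl.flitplus.trainer")
    return getattr(module, FLIT_TRAINERS[trainer_type.lower()])
```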
10 changes: 10 additions & 0 deletions federatedscope/core/configs/cfg_fl_algo.py
@@ -68,6 +68,16 @@ def extend_fl_algo_cfg(cfg):
     cfg.gcflplus.seq_length = 5
     cfg.gcflplus.standardize = False

+    # ------------------------------------------------------------------------ #
+    # FLIT+ related options, gfl
+    # ------------------------------------------------------------------------ #
+    cfg.flitplus = CN()
+
+    cfg.flitplus.tmpFed = 0.5  # gamma in focal loss (Eq. 4)
+    cfg.flitplus.lambdavat = 0.5  # lambda in phi (Eq. 10)
+    cfg.flitplus.factor_ema = 0.8  # beta in omega (Eq. 12)
+    cfg.flitplus.weightReg = 1.0  # balance lossLocalLabel and lossLocalVAT
+
     # --------------- register corresponding check function ----------
     cfg.register_cfg_check_fun(assert_fl_algo_cfg)
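For intuition about these knobs, a plain focal-loss sketch follows, with `gamma` standing in for `cfg.flitplus.tmpFed` (Eq. 4). This is not FederatedScope's FLIT+ implementation, which reweights instances by comparing local and global model predictions and adds a VAT regularizer weighted by `lambdavat`, with `factor_ema` smoothing the weights across rounds.

```python
import torch.nn.functional as F

def focal_loss(logits, target, gamma=0.5):
    """Plain focal loss for intuition; gamma plays the role of tmpFed.

    Well-fitted samples have p_t close to 1, so (1 - p_t) ** gamma shrinks
    their contribution and training focuses on poorly predicted instances.
    """
    logp = F.log_softmax(logits, dim=-1)
    logp_t = logp.gather(1, target.unsqueeze(1)).squeeze(1)  # log p(true class)
    p_t = logp_t.exp()
    return -((1.0 - p_t) ** gamma * logp_t).mean()
```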
3 changes: 2 additions & 1 deletion federatedscope/core/splitters/graph/__init__.py
@@ -12,9 +12,10 @@
 from federatedscope.core.splitters.graph.randchunk_splitter import RandChunkSplitter

 from federatedscope.core.splitters.graph.analyzer import Analyzer
+from federatedscope.core.splitters.graph.scaffold_lda_splitter import ScaffoldLdaSplitter


 __all__ = [
     'LouvainSplitter', 'RandomSplitter', 'RelTypeSplitter', 'ScaffoldSplitter',
-    'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer'
+    'GraphTypeSplitter', 'RandChunkSplitter', 'Analyzer', 'ScaffoldLdaSplitter'
 ]
176 changes: 176 additions & 0 deletions federatedscope/core/splitters/graph/scaffold_lda_splitter.py
@@ -0,0 +1,176 @@
import logging
import numpy as np
import torch

from rdkit import Chem
from rdkit import RDLogger
from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice
from federatedscope.core.splitters.graph.scaffold_splitter import generate_scaffold

logger = logging.getLogger(__name__)

RDLogger.DisableLog('rdApp.*')


class GenFeatures:
    r"""Featurizer mirroring DGL's 'CanonicalAtomFeaturizer' and
    'CanonicalBondFeaturizer'.
    Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html

    Arguments:
        data: a PyG ``Data`` object carrying a ``smiles`` string.
    Returns:
        data: the same ``Data`` object with atom features ``x`` and bond
            features ``edge_attr`` populated.
    """
def __init__(self):
self.symbols = [
'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg',
'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl',
'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn',
'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn',
'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other'
]

self.hybridizations = [
Chem.rdchem.HybridizationType.SP,
Chem.rdchem.HybridizationType.SP2,
Chem.rdchem.HybridizationType.SP3,
Chem.rdchem.HybridizationType.SP3D,
Chem.rdchem.HybridizationType.SP3D2,
'other',
]

self.stereos = [
Chem.rdchem.BondStereo.STEREONONE,
Chem.rdchem.BondStereo.STEREOANY,
Chem.rdchem.BondStereo.STEREOZ,
Chem.rdchem.BondStereo.STEREOE,
Chem.rdchem.BondStereo.STEREOCIS,
Chem.rdchem.BondStereo.STEREOTRANS,
]

def __call__(self, data):
mol = Chem.MolFromSmiles(data.smiles)

xs = []
for atom in mol.GetAtoms():
symbol = [0.] * len(self.symbols)
if atom.GetSymbol() in self.symbols:
symbol[self.symbols.index(atom.GetSymbol())] = 1.
else:
symbol[self.symbols.index('other')] = 1.
degree = [0.] * 10
degree[atom.GetDegree()] = 1.
implicit = [0.] * 6
implicit[atom.GetImplicitValence()] = 1.
formal_charge = atom.GetFormalCharge()
radical_electrons = atom.GetNumRadicalElectrons()
hybridization = [0.] * len(self.hybridizations)
if atom.GetHybridization() in self.hybridizations:
hybridization[self.hybridizations.index(atom.GetHybridization())] = 1.
else:
hybridization[self.hybridizations.index('other')] = 1.
aromaticity = 1. if atom.GetIsAromatic() else 0.
hydrogens = [0.] * 5
hydrogens[atom.GetTotalNumHs()] = 1.

x = torch.tensor(symbol + degree + implicit +
[formal_charge] + [radical_electrons] +
hybridization + [aromaticity] + hydrogens)
xs.append(x)

data.x = torch.stack(xs, dim=0)

edge_attrs = []
for bond in mol.GetBonds():
bond_type = bond.GetBondType()
single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
conjugation = 1. if bond.GetIsConjugated() else 0.
ring = 1. if bond.IsInRing() else 0.
stereo = [0.] * 6
stereo[self.stereos.index(bond.GetStereo())] = 1.

edge_attr = torch.tensor(
[single, double, triple, aromatic, conjugation, ring] + stereo)

edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            # Match the 13-dim width of the bonded branch below
            # (12 bond features + 1 self-loop indicator column).
            data.edge_attr = torch.zeros((0, 13), dtype=torch.float)
else:
num_atoms = mol.GetNumAtoms()
feats = torch.stack(edge_attrs, dim=0)
feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1)
self_loop_feats = torch.zeros(num_atoms, feats.shape[1])
self_loop_feats[:, -1] = 1
feats = torch.cat([feats, self_loop_feats], dim=0)
data.edge_attr = feats

return data


def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1):
    r"""Groups molecules by scaffold, then partitions the scaffold groups
    across clients via a Dirichlet distribution.
    Returns a list of index lists, one per client.
    """
    logger.info('Scaffold split might take minutes, please wait...')
    scaffolds = {}
    for idx, data in enumerate(dataset):
        smiles = data.smiles
scaffold = generate_scaffold(smiles)
if scaffold not in scaffolds:
scaffolds[scaffold] = [idx]
else:
scaffolds[scaffold].append(idx)
    # Sort indices within each scaffold set; the sets themselves are then
    # ordered from largest to smallest below.
    scaffolds = {key: sorted(value) for key, value in scaffolds.items()}
scaffold_list = [
list(scaffold_set)
for (scaffold,
scaffold_set) in sorted(scaffolds.items(),
key=lambda x: (len(x[1]), x[1][0]),
reverse=True)
]
label = np.zeros(len(dataset))
for i in range(len(scaffold_list)):
label[scaffold_list[i]] = i+1
label = torch.LongTensor(label)
# Split data to list
idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha)
return idx_slice


class ScaffoldLdaSplitter:
    r"""First adopts scaffold splitting, then assigns samples to clients
    via latent Dirichlet allocation (LDA).

    Arguments:
        dataset (List or PyG.dataset): The molecular dataset.
        alpha (float): Partition hyperparameter for LDA; a smaller alpha
            yields a more heterogeneous partition.
    Returns:
        data_list (List(List(PyG.data))): The dataset split by scaffold
            and partitioned across clients.
    """
def __init__(self, client_num, alpha):
self.client_num = client_num
self.alpha = alpha

def __call__(self, dataset):
featurizer = GenFeatures()
data = []
for ds in dataset:
ds = featurizer(ds)
data.append(ds)
dataset = data
idx_slice = gen_scaffold_lda_split(dataset, self.client_num, self.alpha)
data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
return data_list

def __repr__(self):
return f'{self.__class__.__name__}()'
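A minimal usage sketch of the splitter, assuming a PyG molecular dataset whose `Data` objects carry a `smiles` attribute (the `MoleculeNet`/BBBP choice here is illustrative):

```python
from torch_geometric.datasets import MoleculeNet

from federatedscope.core.splitters.graph import ScaffoldLdaSplitter

dataset = MoleculeNet(root='data', name='BBBP')  # any dataset with .smiles
splitter = ScaffoldLdaSplitter(client_num=4, alpha=0.1)
data_list = splitter(dataset)  # featurize, group by scaffold, Dirichlet split
print([len(client_data) for client_data in data_list])  # samples per client
```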
Empty file.
32 changes: 32 additions & 0 deletions federatedscope/gfl/flitplus/fedalgo_cls.yaml
@@ -0,0 +1,32 @@
use_gpu: True
device: 0
federate:
mode: 'standalone'
make_global_eval: True
local_update_steps: 333
total_round_num: 30
client_num: 4
sample_client_num: 3
data:
root: data/
splitter: scaffold_lda
batch_size: 64
transform: ['AddSelfLoops']
splitter_args: [{'alpha': 0.1}]
model:
type: mpnn
hidden: 64
task: graph
out_channels: 2
flitplus:
tmpFed: 0.5
factor_ema: 0.8
optimizer:
type: 'Adam'
lr: 0.0001
weight_decay: 0.00001
criterion:
type: CrossEntropyLoss
eval:
freq: 50
metrics: ['roc_auc']
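Presumably this config is launched through FederatedScope's usual entry point, e.g. `python federatedscope/main.py --cfg federatedscope/gfl/flitplus/fedalgo_cls.yaml`. Note that `data.type` and `trainer.type` are not pinned in the file, so they would be supplied by defaults or command-line overrides.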
