Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[feature] FLIT for federated graph classification/regression (#87)
implemented FLIT
- Loading branch information
Showing
19 changed files
with
716 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -135,3 +135,5 @@ dmypy.json | |
.pyre/ | ||
|
||
.idea/ | ||
|
||
**/.DS_Store |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
176 changes: 176 additions & 0 deletions
176
federatedscope/core/splitters/graph/scaffold_lda_splitter.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
import logging | ||
import numpy as np | ||
import torch | ||
|
||
from rdkit import Chem | ||
from rdkit import RDLogger | ||
from rdkit.Chem.Scaffolds import MurckoScaffold | ||
from federatedscope.core.splitters.utils import dirichlet_distribution_noniid_slice | ||
from federatedscope.core.splitters.graph.scaffold_splitter import generate_scaffold | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
RDLogger.DisableLog('rdApp.*') | ||
|
||
class GenFeatures: | ||
r"""Implementation of 'CanonicalAtomFeaturizer' and 'CanonicalBondFeaturizer' in DGL. | ||
Source: https://lifesci.dgl.ai/_modules/dgllife/utils/featurizers.html | ||
Arguments: | ||
data: PyG.data in PyG.dataset. | ||
Returns: | ||
data: PyG.data, data passing featurizer. | ||
""" | ||
def __init__(self): | ||
self.symbols = [ | ||
'C', 'N', 'O', 'S', 'F', 'Si', 'P', 'Cl', 'Br', 'Mg', | ||
'Na', 'Ca', 'Fe', 'As', 'Al', 'I', 'B', 'V', 'K', 'Tl', | ||
'Yb', 'Sb', 'Sn', 'Ag', 'Pd', 'Co', 'Se', 'Ti', 'Zn', | ||
'H', 'Li', 'Ge', 'Cu', 'Au', 'Ni', 'Cd', 'In', 'Mn', | ||
'Zr', 'Cr', 'Pt', 'Hg', 'Pb', 'other' | ||
] | ||
|
||
self.hybridizations = [ | ||
Chem.rdchem.HybridizationType.SP, | ||
Chem.rdchem.HybridizationType.SP2, | ||
Chem.rdchem.HybridizationType.SP3, | ||
Chem.rdchem.HybridizationType.SP3D, | ||
Chem.rdchem.HybridizationType.SP3D2, | ||
'other', | ||
] | ||
|
||
self.stereos = [ | ||
Chem.rdchem.BondStereo.STEREONONE, | ||
Chem.rdchem.BondStereo.STEREOANY, | ||
Chem.rdchem.BondStereo.STEREOZ, | ||
Chem.rdchem.BondStereo.STEREOE, | ||
Chem.rdchem.BondStereo.STEREOCIS, | ||
Chem.rdchem.BondStereo.STEREOTRANS, | ||
] | ||
|
||
def __call__(self, data): | ||
mol = Chem.MolFromSmiles(data.smiles) | ||
|
||
xs = [] | ||
for atom in mol.GetAtoms(): | ||
symbol = [0.] * len(self.symbols) | ||
if atom.GetSymbol() in self.symbols: | ||
symbol[self.symbols.index(atom.GetSymbol())] = 1. | ||
else: | ||
symbol[self.symbols.index('other')] = 1. | ||
degree = [0.] * 10 | ||
degree[atom.GetDegree()] = 1. | ||
implicit = [0.] * 6 | ||
implicit[atom.GetImplicitValence()] = 1. | ||
formal_charge = atom.GetFormalCharge() | ||
radical_electrons = atom.GetNumRadicalElectrons() | ||
hybridization = [0.] * len(self.hybridizations) | ||
if atom.GetHybridization() in self.hybridizations: | ||
hybridization[self.hybridizations.index(atom.GetHybridization())] = 1. | ||
else: | ||
hybridization[self.hybridizations.index('other')] = 1. | ||
aromaticity = 1. if atom.GetIsAromatic() else 0. | ||
hydrogens = [0.] * 5 | ||
hydrogens[atom.GetTotalNumHs()] = 1. | ||
|
||
x = torch.tensor(symbol + degree + implicit + | ||
[formal_charge] + [radical_electrons] + | ||
hybridization + [aromaticity] + hydrogens) | ||
xs.append(x) | ||
|
||
data.x = torch.stack(xs, dim=0) | ||
|
||
edge_attrs = [] | ||
for bond in mol.GetBonds(): | ||
bond_type = bond.GetBondType() | ||
single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0. | ||
double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0. | ||
triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0. | ||
aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0. | ||
conjugation = 1. if bond.GetIsConjugated() else 0. | ||
ring = 1. if bond.IsInRing() else 0. | ||
stereo = [0.] * 6 | ||
stereo[self.stereos.index(bond.GetStereo())] = 1. | ||
|
||
edge_attr = torch.tensor( | ||
[single, double, triple, aromatic, conjugation, ring] + stereo) | ||
|
||
edge_attrs += [edge_attr, edge_attr] | ||
|
||
if len(edge_attrs) == 0: | ||
data.edge_index = torch.zeros((2, 0), dtype=torch.long) | ||
data.edge_attr = torch.zeros((0, 10), dtype=torch.float) | ||
else: | ||
num_atoms = mol.GetNumAtoms() | ||
feats = torch.stack(edge_attrs, dim=0) | ||
feats = torch.cat([feats, torch.zeros(feats.shape[0], 1)], dim=1) | ||
self_loop_feats = torch.zeros(num_atoms, feats.shape[1]) | ||
self_loop_feats[:, -1] = 1 | ||
feats = torch.cat([feats, self_loop_feats], dim=0) | ||
data.edge_attr = feats | ||
|
||
return data | ||
|
||
|
||
def gen_scaffold_lda_split(dataset, client_num=5, alpha=0.1): | ||
r""" | ||
return dict{ID:[idxs]} | ||
""" | ||
logger.info('Scaffold split might take minutes, please wait...') | ||
scaffolds = {} | ||
for idx, data in enumerate(dataset): | ||
smiles = data.smiles | ||
mol = Chem.MolFromSmiles(smiles) | ||
scaffold = generate_scaffold(smiles) | ||
if scaffold not in scaffolds: | ||
scaffolds[scaffold] = [idx] | ||
else: | ||
scaffolds[scaffold].append(idx) | ||
# Sort from largest to smallest scaffold sets | ||
scaffolds = {key: sorted(value) for key, value in scaffolds.items()} | ||
scaffold_list = [ | ||
list(scaffold_set) | ||
for (scaffold, | ||
scaffold_set) in sorted(scaffolds.items(), | ||
key=lambda x: (len(x[1]), x[1][0]), | ||
reverse=True) | ||
] | ||
label = np.zeros(len(dataset)) | ||
for i in range(len(scaffold_list)): | ||
label[scaffold_list[i]] = i+1 | ||
label = torch.LongTensor(label) | ||
# Split data to list | ||
idx_slice = dirichlet_distribution_noniid_slice(label, client_num, alpha) | ||
return idx_slice | ||
|
||
|
||
class ScaffoldLdaSplitter: | ||
r"""First adopt scaffold splitting and then assign the samples to clients according to Latent Dirichlet Allocation. | ||
Arguments: | ||
dataset (List or PyG.dataset): The molecular datasets. | ||
alpha (float): Partition hyperparameter in LDA, smaller alpha generates more extreme heterogeneous scenario. | ||
Returns: | ||
data_list (List(List(PyG.data))): Splited dataset via scaffold split. | ||
""" | ||
def __init__(self, client_num, alpha): | ||
self.client_num = client_num | ||
self.alpha = alpha | ||
|
||
def __call__(self, dataset): | ||
featurizer = GenFeatures() | ||
data = [] | ||
for ds in dataset: | ||
ds = featurizer(ds) | ||
data.append(ds) | ||
dataset = data | ||
idx_slice = gen_scaffold_lda_split(dataset, self.client_num, self.alpha) | ||
data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice] | ||
return data_list | ||
|
||
def __repr__(self): | ||
return f'{self.__class__.__name__}()' |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
use_gpu: True | ||
device: 0 | ||
federate: | ||
mode: 'standalone' | ||
make_global_eval: True | ||
local_update_steps: 333 | ||
total_round_num: 30 | ||
client_num: 4 | ||
sample_client_num: 3 | ||
data: | ||
root: data/ | ||
splitter: scaffold_lda | ||
batch_size: 64 | ||
transform: ['AddSelfLoops'] | ||
splitter_args: [{'alpha': 0.1}] | ||
model: | ||
type: mpnn | ||
hidden: 64 | ||
task: graph | ||
out_channels: 2 | ||
flitplus: | ||
tmpFed: 0.5 | ||
factor_ema: 0.8 | ||
optimizer: | ||
type: 'Adam' | ||
lr: 0.0001 | ||
weight_decay: 0.00001 | ||
criterion: | ||
type: CrossEntropyLoss | ||
eval: | ||
freq: 50 | ||
metrics: ['roc_auc'] |
Oops, something went wrong.