Skip to content

Commit

Permalink
removes finetuning functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
scaldas committed Jun 3, 2019
1 parent bf7e700 commit a88bddc
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 174 deletions.
3 changes: 0 additions & 3 deletions models/baseline_constants.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
"""Configuration file for common models/experiments"""

SIM_TIMES = ['small', 'medium', 'large']
"""list: Common sets of configuration for simulations"""

MAIN_PARAMS = {
'sent140': {
'small': (10, 2, 2),
Expand Down
114 changes: 20 additions & 94 deletions models/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@

import metrics.writer as metrics_writer

from baseline_constants import MAIN_PARAMS, MODEL_PARAMS, SIM_TIMES
from baseline_constants import MAIN_PARAMS, MODEL_PARAMS
from client import Client
from server import Server
from model import ServerModel

from utils.constants import DATASETS
from utils.args import parse_args
from utils.model_utils import read_data

STAT_METRICS_PATH = 'metrics/stat_metrics.csv'
Expand Down Expand Up @@ -61,7 +61,7 @@ def main():
server = Server(client_model)

# Create clients
train_clients, test_clients = setup_clients(args.dataset, client_model, args.is_client_split)
train_clients, test_clients = setup_clients(args.dataset, client_model)
train_ids, train_groups, num_train_samples, _ = server.get_clients_info(train_clients)
test_ids, test_groups, _, num_test_samples = server.get_clients_info(test_clients)
print('Clients in Total: %d train, %d test' % (len(train_clients), len(test_clients)))
Expand All @@ -70,21 +70,13 @@ def main():
print('--- Round 0 of %d ---' % (num_rounds))
train_stat_metrics = server.get_train_stats(train_clients)
print_metrics(train_stat_metrics, num_train_samples, prefix='train_')
stat_metrics = server.test_model(test_clients, query_fraction=0.1, num_epochs=args.num_epochs, batch_size=args.batch_size)
stat_metrics = server.test_model(test_clients)
metrics_writer.print_metrics(0, test_ids, stat_metrics, test_groups, num_test_samples, STAT_METRICS_PATH)
print_metrics(stat_metrics, num_test_samples, prefix='test_')

# Simulate training
for i in range(num_rounds):
print('--- Round %d of %d: Training %d Clients ---' % (i+1, num_rounds, clients_per_round))

# Test model on all clients
if i % eval_every == 0 or i == num_rounds:
train_stat_metrics = server.get_train_stats(clients)
print_metrics(train_stat_metrics, all_train_samples, prefix='train_')
stat_metrics = server.test_model(clients)
metrics_writer.print_metrics(i, all_ids, stat_metrics, all_groups, all_num_samples, STAT_METRICS_PATH)
print_metrics(stat_metrics, all_num_samples, prefix='test_')
print('--- Round %d of %d: Training %d Clients ---' % (i + 1, num_rounds, clients_per_round))

# Select clients to train this round
server.select_clients(i, online(train_clients), num_clients=clients_per_round)
Expand All @@ -98,92 +90,36 @@ def main():
metrics_writer.print_metrics(i, c_ids, sys_metics, c_groups, c_num_samples, SYS_METRICS_PATH)

# Test model
if (i+1) % eval_every == 0 or i == num_rounds:
if (i + 1) % eval_every == 0 or (i + 1) == num_rounds:
train_stat_metrics = server.get_train_stats(train_clients)
print_metrics(train_stat_metrics, num_train_samples, prefix='train_')
stat_metrics = server.test_model(test_clients, query_fraction=0.1, num_epochs=args.num_epochs, batch_size=args.batch_size)
metrics_writer.print_metrics((i+1), test_ids, stat_metrics, test_groups, num_test_samples, STAT_METRICS_PATH)
test_stat_metrics = server.test_model(test_clients)
metrics_writer.print_metrics((i + 1), test_ids, stat_metrics, test_groups, num_test_samples, STAT_METRICS_PATH)
print_metrics(stat_metrics, num_test_samples, prefix='test_')

# Save server model
# save_model(server_model, dataset, model)
ckpt_path = os.path.join('checkpoints', args.dataset)
if not os.path.exists(ckpt_path):
os.makedirs(ckpt_path)
save_path = server.save_model(os.path.join(ckpt_path, '{}.ckpt'.format(args.model)))
print('Model saved in path: %s' % save_path)

# Close models
# server_model.close()
client_model.close()

server.close_model()

def online(clients):
    """Return the clients available for training this round.

    The simulation assumes every client is always online, so this is
    the identity function on *clients*.
    """
    return clients


def parse_args():
    """Build and parse the command-line options for a simulation run."""
    p = argparse.ArgumentParser()

    # Required experiment selectors.
    p.add_argument('-dataset', help='name of dataset;', type=str, choices=DATASETS, required=True)
    p.add_argument('-model', help='name of model;', type=str, required=True)

    # Simulation schedule.
    p.add_argument('--num-rounds', help='number of rounds to simulate;', type=int, default=-1)
    p.add_argument('--eval-every', help='evaluate every ____ rounds;', type=int, default=-1)
    p.add_argument('--clients-per-round', help='number of clients trained per round;', type=int, default=-1)

    p.add_argument('--batch_size', help='batch size when clients train on data;', type=int, default=10)
    p.add_argument('--seed', help='seed for random client sampling and batch splitting', type=int, default=0)
    p.add_argument('--is-client-split', action='store_true', help='data split is according to clients')

    # Minibatch doesn't support num_epochs, so make them mutually exclusive.
    excl = p.add_mutually_exclusive_group()
    excl.add_argument('--minibatch', help='None for FedAvg, else fraction;', type=float, default=None)
    excl.add_argument('--num_epochs', help='number of epochs when clients train on data;', type=int, default=1)

    p.add_argument('-t', help='simulation time: small, medium, or large;', type=str, choices=SIM_TIMES, default='large')
    p.add_argument('-lr', help='learning rate for local optimizers;', type=float, default=-1, required=False)

    return p.parse_args()

def create_clients(users, groups, train_data, test_data, model):
    """Build one Client per user.

    Args:
        users: list of user ids; each id must key into train_data and test_data.
        groups: per-user group labels, or an empty list when the dataset has
            no group information (an empty group is then substituted per user).
        train_data: dict mapping user id -> that user's training data.
        test_data: dict mapping user id -> that user's test data.
        model: model instance shared by all created clients.

    Returns:
        list of Client objects, one per user, in the order of `users`.
    """
    if not groups:  # idiomatic emptiness check (was: len(groups) == 0)
        groups = [[] for _ in users]
    return [Client(u, g, train_data[u], test_data[u], model)
            for u, g in zip(users, groups)]

def setup_clients(dataset, model=None, is_client_split=False):

def setup_clients(dataset, model=None):
"""Instantiates clients based on given train and test data directories.
Return:
Expand All @@ -192,25 +128,15 @@ def setup_clients(dataset, model=None, is_client_split=False):
train_data_dir = os.path.join('..', 'data', dataset, 'data', 'train')
test_data_dir = os.path.join('..', 'data', dataset, 'data', 'test')

conv_data = read_data(train_data_dir, test_data_dir, is_client_split)
(train_users, test_users), (train_groups, test_groups), train_data, test_data = conv_data
conv_data = read_data(train_data_dir, test_data_dir)
users, groups, train_data, test_data = conv_data

train_clients = create_clients(train_users, train_groups, train_data, test_data, model)
test_clients = create_clients(test_users, test_groups, train_data, test_data, model)
train_clients = create_clients(users, groups, train_data, test_data, model)
test_clients = create_clients(users, groups, train_data, test_data, model)

return train_clients, test_clients


def save_model(server_model, dataset, model):
    """Saves the given server model on checkpoints/dataset/model.ckpt.

    Args:
        server_model: object exposing a `save(path)` method that returns
            the path it wrote to.
        dataset: dataset name; used as the checkpoint sub-directory.
        model: model name; used as the checkpoint file stem.
    """
    ckpt_path = os.path.join('checkpoints', dataset)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(ckpt_path, exist_ok=True)
    save_path = server_model.save(os.path.join(ckpt_path, '%s.ckpt' % model))
    print('Model saved in path: %s' % save_path)


def print_metrics(metrics, weights, prefix=''):
"""Prints weighted averages of the given metrics.
Expand Down
5 changes: 0 additions & 5 deletions models/metl_femnist.sh

This file was deleted.

19 changes: 1 addition & 18 deletions models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,23 +81,6 @@ def train(self, data, num_epochs=1, batch_size=10):
update: List of np.ndarray weights, with each weight array
corresponding to a variable in the resulting graph
"""
<<<<<<< HEAD
=======
for _ in range(num_epochs):
for batched_x, batched_y in batch_data(data, batch_size):
input_data = self.process_x(batched_x)
target_data = self.process_y(batched_y)
with self.graph.as_default():
self.sess.run(self.train_op,
feed_dict={self.features: input_data, self.labels: target_data})
update = self.get_params()
comp = num_epochs * (len(data['y'])//batch_size) * batch_size * self.flops
return comp, update

# TODO: Confirm if num_epochs semantics should be changed to minibatches?
# Right now, this is equivalent to the train method (as it should be :p)
def finetune(self, data, num_epochs, batch_size):
>>>>>>> fb15579... parent bd814c55ee84a1a6184ab1155f2baf13cdab2016
for _ in range(num_epochs):
for batched_x, batched_y in batch_data(data, batch_size):
input_data = self.process_x(batched_x)
Expand Down Expand Up @@ -184,7 +167,7 @@ def update(self, updates):

weighted_vals = [np.zeros(np.shape(v), dtype=float) for v in updates[0][1]]

for i, update in enumerate(updates):
for _, update in enumerate(updates):
for j, weighted_val in enumerate(weighted_vals):
weighted_vals[j] = np.add(weighted_val, update[0] * update[1][j])

Expand Down
33 changes: 12 additions & 21 deletions models/server.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import numpy as np

from baseline_constants import BYTES_WRITTEN_KEY, BYTES_READ_KEY, LOCAL_COMPUTATIONS_KEY
from utils.model_utils import gen_frac_query

class Server:

Expand All @@ -11,7 +10,7 @@ def __init__(self, client_model):
self.selected_clients = []
self.updates = []

def select_clients(self, round, possible_clients, num_clients=20):
def select_clients(self, my_round, possible_clients, num_clients=20):
"""Selects num_clients clients randomly from possible_clients.
Note that within function, num_clients is set to
Expand All @@ -24,7 +23,7 @@ def select_clients(self, round, possible_clients, num_clients=20):
list of (num_train_samples, num_test_samples)
"""
num_clients = min(num_clients, len(possible_clients))
np.random.seed(round)
np.random.seed(my_round)
self.selected_clients = np.random.choice(possible_clients, num_clients, replace=False)

return [(c.num_train_samples, c.num_test_samples) for c in self.selected_clients]
Expand Down Expand Up @@ -68,24 +67,6 @@ def train_model(self, num_epochs=1, batch_size=10, minibatch=None, clients=None)

return sys_metrics

def metatest_model(self, clients, query_fraction, num_epochs=1, batch_size=10):
metrics = {}

cur_model_params = self.model
for client in clients:
support, query = gen_frac_query(client.eval_data, query_fraction)

client.model.set_params(cur_model_params)
_, finetuned_model = client.model.finetune(support, num_epochs, batch_size)
client.model.set_params(finetuned_model)

c_metrics = client.test(query)
metrics[client.id] = c_metrics
self.model = cur_model_params
self.client_model.set_params(cur_model_params)

return metrics

def update_model(self):
total_weight = 0.
base = [0] * len(self.updates[0][1])
Expand Down Expand Up @@ -138,3 +119,13 @@ def get_clients_info(self, clients):
num_test_samples = {c.id: c.num_test_samples for c in clients}
num_train_samples = {c.id: c.num_train_samples for c in clients}
return ids, groups, num_train_samples, num_test_samples

def save_model(self, path):
    """Persist the current server parameters to *path*.

    The server-side weights are first copied into the shared client
    model; that model's saver then writes its session state to disk.

    Returns:
        The save path reported by the underlying saver.
    """
    # Sync server weights into the client model before saving.
    self.client_model.set_params(self.model)
    return self.client_model.saver.save(self.client_model.sess, path)

def close_model(self):
    """Close the shared client model, releasing whatever resources it holds."""
    self.client_model.close()
61 changes: 61 additions & 0 deletions models/utils/args.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import argparse

from .constants import DATASETS, SIM_TIMES


def parse_args():
    """Parse the command-line flags shared by the simulation scripts."""
    ap = argparse.ArgumentParser()

    # Required: which dataset and model to run.
    ap.add_argument('-dataset', help='name of dataset;', type=str, choices=DATASETS, required=True)
    ap.add_argument('-model', help='name of model;', type=str, required=True)

    ap.add_argument('--num-rounds', help='number of rounds to simulate;', type=int, default=-1)
    ap.add_argument('--eval-every', help='evaluate every ____ rounds;', type=int, default=-1)
    ap.add_argument('--clients-per-round', help='number of clients trained per round;', type=int, default=-1)
    ap.add_argument('--batch-size', help='batch size when clients train on data;', type=int, default=10)
    ap.add_argument('--seed', help='seed for random client sampling and batch splitting', type=int, default=0)

    # Minibatch doesn't support num_epochs, so make them mutually exclusive.
    excl = ap.add_mutually_exclusive_group()
    excl.add_argument('--minibatch', help='None for FedAvg, else fraction;', type=float, default=None)
    excl.add_argument('--num-epochs', help='number of epochs when clients train on data;', type=int, default=1)

    ap.add_argument('-t', help='simulation time: small, medium, or large;', type=str, choices=SIM_TIMES, default='large')
    ap.add_argument('-lr', help='learning rate for local optimizers;', type=float, default=-1, required=False)

    return ap.parse_args()
5 changes: 4 additions & 1 deletion models/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,4 @@
DATASETS = ['sent140', 'femnist', 'shakespeare']
DATASETS = ['sent140', 'femnist', 'shakespeare']

"""list: Common sets of configuration for simulations"""
SIM_TIMES = ['small', 'medium', 'large']
Loading

0 comments on commit a88bddc

Please sign in to comment.