Skip to content

Commit

Permalink
collate from different machines
Browse files (browse the repository at this point in the history)
  • Loading branch information
gngdb committed Feb 9, 2019
1 parent 10604c4 commit c723e08
Showing 1 changed file with 40 additions and 35 deletions.
75 changes: 40 additions & 35 deletions collate_results.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -42,8 +42,8 @@ def parse_checkpoint(ckpt_name, ckpt_contents):
results['val_errors'] = [float(x) for x in ckpt_contents['val_errors']]
results['train_errors'] = [float(x) for x in ckpt_contents['train_errors']]
# hard part: count parameters by making an instance of the network
network = {'wrn_28_10': 'WideResNet', 'darts': 'DARTS'}[ckpt_name.split(".")[0]]
h,w = {'WideResNet': (32,32), 'DARTS': (32,32)}[network]
network = {'wrn_28_10': 'WideResNet', 'darts': 'DARTS', 'wrn_50_2': 'WRN_50_2'}[ckpt_name.split(".")[0]]
h,w = {'WideResNet': (32,32), 'DARTS': (32,32), 'WRN_50_2': (224,224)}[network]
SavedConv, SavedBlock = what_conv_block(ckpt_contents['conv'],
ckpt_contents['blocktype'], ckpt_contents['module'])
model = build_network(SavedConv, SavedBlock, network)
Expand Down Expand Up @@ -80,39 +80,44 @@ def keep_oldest(collated, ckpt_name, ckpt_contents):
return collated[ckpt_name]

def main():
# prepare directory
if not os.path.exists("collate"):
os.mkdir("collate")
else:
# clean up directory
old_ckpts = os.listdir("collate")
for c in old_ckpts:
os.remove(os.path.join("collate", c))

# read the schedule from json
json_path = sys.argv[1]
with open(json_path, "r") as f:
schedule = json.load(f)
# make a list of all the checkpoint files we need to check
checkpoints = []
for e in schedule:
checkpoints.append(ckpt_name(e)+".t7")
# look for these checkpoints on every machine we know about
collated = []
for m in tqdm(machines, desc='machine'):
# connect to the remote machine
hostname, directory = m.split(":")
checkpoint_dir = os.path.join(directory, "checkpoints")
completed = subprocess.run(f"ssh {hostname} ls {checkpoint_dir}".split(" "), stdout=PIPE, stderr=PIPE)
checkpoints_on_remote = completed.stdout.decode().split("\n")

# look for overlap between that and the checkpoints we care about
overlap = list(set(checkpoints_on_remote) & set(checkpoints))
for checkpoint in tqdm(overlap, desc="copying"):
checkpoint_loc = os.path.join(checkpoint_dir, checkpoint)
checkpoint_dest = f"collate/{hostname}.{checkpoint}"
if not os.path.exists(checkpoint_dest):
subprocess.run(f"scp {hostname}:{checkpoint_loc} {checkpoint_dest}".split(" "), stdout=PIPE, stderr=PIPE)
try:
# read the schedule from json
json_path = sys.argv[1]
with open(json_path, "r") as f:
schedule = json.load(f)

# prepare directory
if not os.path.exists("collate"):
os.mkdir("collate")
else:
# clean up directory
old_ckpts = os.listdir("collate")
for c in old_ckpts:
os.remove(os.path.join("collate", c))

# make a list of all the checkpoint files we need to check
checkpoints = []
for e in schedule:
checkpoints.append(ckpt_name(e)+".t7")
# look for these checkpoints on every machine we know about
collated = []
for m in tqdm(machines, desc='machine'):
# connect to the remote machine
hostname, directory = m.split(":")
checkpoint_dir = os.path.join(directory, "checkpoints")
completed = subprocess.run(f"ssh {hostname} ls {checkpoint_dir}".split(" "), stdout=PIPE, stderr=PIPE)
checkpoints_on_remote = completed.stdout.decode().split("\n")

# look for overlap between that and the checkpoints we care about
overlap = list(set(checkpoints_on_remote) & set(checkpoints))
for checkpoint in tqdm(overlap, desc="copying"):
checkpoint_loc = os.path.join(checkpoint_dir, checkpoint)
checkpoint_dest = f"collate/{hostname}.{checkpoint}"
if not os.path.exists(checkpoint_dest):
subprocess.run(f"scp {hostname}:{checkpoint_loc} {checkpoint_dest}".split(" "), stdout=PIPE, stderr=PIPE)

except IndexError:
pass

# iterate over copied files
collated = OrderedDict()
Expand Down

0 comments on commit c723e08

Please sign in to comment.