Skip to content

Commit

Permalink
fix conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
Abe404 committed Nov 13, 2021
2 parents 4c4f1b7 + 88bbb7a commit 264c999
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
28 changes: 28 additions & 0 deletions trainer/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,34 @@ def fake_cnn(tiles_for_gpu):
output.append((v > v_mean).astype(np.int8))
return np.array(output)

def get_in_w_out_w_pairs():
# matching pairs of input/output sizes for specific unet used
# 52 to 228 in incrememnts of 16 (sorted large to small)
in_w_list = sorted([52 + (x*16) for x in range(12)], reverse=True)

# output always 34 less than input
out_w_list = [x - 34 for x in in_w_list]
return list(zip(in_w_list, out_w_list))

def get_in_w_out_w_for_memory(num_classes):
# search for appropriate input size for GPU
# in_w, out_w = get_in_w_out_w_for_memory(num_classes)
net = UNet3D(im_channels=1, out_channels=num_classes*2).cuda()
net = torch.nn.DataParallel(net)
for in_w, out_w in get_in_w_out_w_pairs():
torch.cuda.empty_cache()
try:
# b, c, d, h, w
input_data = np.zeros((4, 1, 52, in_w, in_w))
output = net(torch.from_numpy(input_data).cuda().float())
del input_data
del output
torch.cuda.empty_cache()
return in_w, out_w
except Exception as e:
if 'out of memory' in str(e):
print(in_w, out_w, 'too big')
raise Exception('Could not find patch small enough for available GPU memory')


def get_latest_model_paths(model_dir, k):
Expand Down
11 changes: 11 additions & 0 deletions trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,17 @@ def main_loop(self, on_epoch_end=None):
time.sleep(1.0)


def add_config_shape(self, config):
new_config = copy.deepcopy(config)
num_classes = len(config['classes'])
in_w, out_w = model_utils.get_in_w_out_w_for_memory(num_classes)
print('found input width of', in_w, 'and output width of', out_w)
new_config['in_w'] = in_w
new_config['out_w'] = out_w
new_config['in_d'] = 52
new_config['out_d'] = 18
return new_config

def fix_config_paths(self, old_config):
""" get paths relative to local machine """
new_config = {}
Expand Down

0 comments on commit 264c999

Please sign in to comment.