8 changes: 7 additions & 1 deletion scripts/train.py
100755 → 100644
@@ -13,8 +13,13 @@ def train(args, train_loader, model, device, criterion,
record = AverageMeter()
model.train()

for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader):
# for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader):
for iteration, batch in enumerate(train_loader):

if args.task == 22:
_, volume, seg_mask, class_weight, _, label, out_skeleton = batch
else:
_, volume, label, class_weight, _ = batch
volume, label = volume.to(device), label.to(device)
class_weight = class_weight.to(device)
output = model(volume)
@@ -94,3 +99,4 @@ def main():

if __name__ == "__main__":
main()

2 changes: 1 addition & 1 deletion torch_connectomics/data/utils/__init__.py
@@ -1 +1 @@
from .functional_collate import collate_fn, collate_fn_test
from .functional_collate import collate_fn, collate_fn_test, collate_fn_skel
18 changes: 17 additions & 1 deletion torch_connectomics/data/utils/functional_collate.py
@@ -41,4 +41,20 @@ def collate_fn_plus(batch):
for i in range(len(others)):
extra[i] = torch.stack(others[i], 0)

return pos, out_input, out_label, weights, weight_factor, extra
return pos, out_input, out_label, weights, weight_factor, extra

def collate_fn_skel(batch):
"""
Puts each data field into a tensor with outer dimension batch size
:param batch:
:return:
"""
pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton = zip(*batch)
out_input = torch.stack(out_input, 0)
out_label = torch.stack(out_label, 0)
weights = torch.stack(weights, 0)
weight_factor = np.stack(weight_factor, 0)
out_distance = np.stack(out_distance, 0)
out_skeleton = np.stack(out_skeleton, 0)

return pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton
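As a sanity check on the new collate function, here is a minimal, self-contained sketch (not part of this PR) that builds two fake samples in the field order `collate_fn_skel` expects and collates them. All shapes, dtypes, and the `make_sample` helper are invented purely for illustration; it assumes this branch of the package is importable.

```python
import numpy as np
import torch
from torch_connectomics.data.utils import collate_fn_skel

def make_sample():
    # Hypothetical sample in the order MitoSkeletonDataset is expected to yield:
    # pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton
    pos = (0, 0, 0, 0)                                      # sampling position
    out_input = torch.rand(1, 8, 64, 64)                    # image volume
    out_label = torch.randint(0, 2, (8, 64, 64)).float()    # segmentation mask
    weights = torch.ones(8, 64, 64)                          # per-voxel class weights
    weight_factor = np.float32(1.0)                           # per-sample weight factor
    out_distance = np.random.rand(8, 64, 64).astype(np.float32)  # distance/skeleton target
    out_skeleton = np.zeros((8, 64, 64), dtype=np.uint8)          # skeleton volume
    return pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton

batch = [make_sample(), make_sample()]
pos, vol, lab, w, wf, dist, skel = collate_fn_skel(batch)
print(vol.shape)   # torch.Size([2, 1, 8, 64, 64]) -- tensor fields are torch.stack-ed
print(dist.shape)  # (2, 8, 64, 64)                -- NumPy fields are np.stack-ed
```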
2 changes: 1 addition & 1 deletion torch_connectomics/model/model_zoo/unetv0.py
@@ -75,7 +75,7 @@ def __init__(self, in_channel=1, out_channel=3, filters=[32,64,128,256,256], act
self.fconv = conv3d_bn_non(filters[0], out_channel, kernel_size=(3,3,3), padding=(1,1,1))

#final layer activation
if act='tanh':
if act == 'tanh':
self.act = nn.Tanh()
else:
self.act = nn.Sigmoid()
14 changes: 12 additions & 2 deletions torch_connectomics/utils/net/arguments.py
@@ -7,8 +7,15 @@ def get_args(mode='train'):
else:
parser = argparse.ArgumentParser(description='Specify model inference arguments.')

# define tasks
# {0: neuron segmentationn, 1: synapse detection, 2: mitochondira segmentation}
"""
define tasks
{0: 'neuron segmentation',
1: 'synapse detection',
11: 'synapse polarity detection',
2: 'mitochondria segmentation',
22: 'mitochondria segmentation with skeleton transform'}
"""

parser.add_argument('--task', type=int, default=0,
help='specify the task')

@@ -43,6 +50,9 @@ def get_args(mode='train'):
parser.add_argument('-ln','--seg-name', default='seg-groundtruth2-malis.h5',
help='Ground-truth label path')

parser.add_argument('-vm','--valid-mask', default=None,
help='Mask for the train images')

parser.add_argument('-ft','--finetune', type=bool, default=False,
help='Fine-tune on previous model [Default: False]')

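For reference, a small sketch of how the new task id and validity mask would be selected, using only the options visible in this file. The parser description and mask file names are placeholders, and the real get_args() builds many more options; the '@'-separated mask string mirrors how get_input() later splits args.valid_mask.

```python
import argparse

# Cut-down parser containing only the options touched by this PR; flag names and
# defaults mirror torch_connectomics/utils/net/arguments.py, everything else is omitted.
parser = argparse.ArgumentParser(description='Illustrative subset of the training arguments.')
parser.add_argument('--task', type=int, default=0, help='specify the task')
parser.add_argument('-ln', '--seg-name', default='seg-groundtruth2-malis.h5',
                    help='Ground-truth label path')
parser.add_argument('-vm', '--valid-mask', default=None,
                    help='Mask for the train images')

# Hypothetical invocation: task 22 with two '@'-separated mask volumes.
args = parser.parse_args(['--task', '22', '-vm', 'mask-vol1.h5@mask-vol2.h5'])
print(args.task, args.valid_mask)   # 22 mask-vol1.h5@mask-vol2.h5
```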
39 changes: 34 additions & 5 deletions torch_connectomics/utils/net/dataload.py
100755 → 100644
@@ -8,7 +8,7 @@
import torchvision.utils as vutils

from torch_connectomics.data.dataset import AffinityDataset, SynapseDataset, SynapsePolarityDataset, MitoDataset, MitoSkeletonDataset
from torch_connectomics.data.utils import collate_fn, collate_fn_test
from torch_connectomics.data.utils import collate_fn, collate_fn_test, collate_fn_skel
from torch_connectomics.data.augmentation import *

TASK_MAP = {0: 'neuron segmentation',
@@ -38,12 +38,19 @@ def get_input(args, model_io_size, mode='train'):
if mode=='train':
seg_name = args.seg_name.split('@')
seg_name = [dir_name[0] + x for x in seg_name]
if args.valid_mask is not None:
mask_names = args.valid_mask.split('@')
mask_locations = [dir_name[0] + x for x in mask_names]

# 1. load data
model_input = [None]*len(img_name)
if mode=='train':
assert len(img_name)==len(seg_name)
model_label = [None]*len(seg_name)
model_mask = None
if args.valid_mask is not None:
    assert len(img_name) == len(mask_locations)
    model_mask = [None] * len(mask_locations)


for i in range(len(img_name)):
model_input[i] = np.array(h5py.File(img_name[i], 'r')['main'])/255.0
@@ -61,7 +68,16 @@ def get_input(args, model_io_size, mode='train'):
model_label[i] = np.pad(model_label[i], ((pad_size[0],pad_size[0]),
(pad_size[1],pad_size[1]),
(pad_size[2],pad_size[2])), 'reflect')

assert model_input[i].shape == model_label[i].shape
if args.valid_mask is not None:
# load the validity mask and pad it to match the (already padded) input volume
model_mask[i] = np.array(h5py.File(mask_locations[i], 'r')['main'])
model_mask[i] = model_mask[i].astype(np.float32)
print(f"mask shape: {model_mask[i].shape}")
model_mask[i] = np.pad(model_mask[i], ((pad_size[0],pad_size[0]),
                                       (pad_size[1],pad_size[1]),
                                       (pad_size[2],pad_size[2])), 'reflect')
assert model_input[i].shape == model_mask[i].shape

if mode=='train':
# setup augmentor
@@ -102,11 +118,15 @@ def get_input(args, model_io_size, mode='train'):
sample_label_size=sample_input_size, augmentor=augmentor, mode = 'train')
if args.task == 22: # mitochondria segmentation with skeleton transform
dataset = MitoSkeletonDataset(volume=model_input, label=model_label, sample_input_size=sample_input_size,
sample_label_size=sample_input_size, augmentor=augmentor, mode = 'train')
sample_label_size=sample_input_size, augmentor=augmentor, valid_mask=model_mask, mode='train')
img_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn_skel,
num_workers=args.num_cpu, pin_memory=True)
return img_loader

img_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn,
num_workers=args.num_cpu, pin_memory=True)
dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn,
num_workers=args.num_cpu, pin_memory=True)
return img_loader

else:
@@ -118,12 +138,21 @@ def get_input(args, model_io_size, mode='train'):
dataset = SynapseDataset(volume=model_input, label=None, sample_input_size=model_io_size, \
sample_label_size=None, sample_stride=model_io_size // 2, \
augmentor=None, mode='test')
elif args.task == 11:
dataset = SynapsePolarityDataset(volume=model_input, label=None, sample_input_size=model_io_size,
sample_label_size=None, sample_stride=model_io_size // 2, \
augmentor=None, mode = 'test')
elif args.task == 2:
dataset = MitoSkeletonDataset(volume=model_input, label=None, sample_input_size=model_io_size, \
dataset = MitoDataset(volume=model_input, label=None, sample_input_size=model_io_size, \
sample_label_size=None, sample_stride=model_io_size // 2, \
augmentor=None, mode='test')
elif args.task == 22:
dataset = MitoSkeletonDataset(volume=model_input, label=None, sample_input_size=model_io_size, \
sample_label_size=None, sample_stride=model_io_size // 2, \
augmentor=None, mode='test')

img_loader = torch.utils.data.DataLoader(
dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn_test,
num_workers=args.num_cpu, pin_memory=True)
return img_loader, volume_shape, pad_size
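Putting the pieces together: for task 22 the loader returned by get_input is built with collate_fn_skel, so each batch carries seven fields, and scripts/train.py unpacks them as shown at the top of this PR. Below is a hedged sketch of that consuming loop, mirroring the body of train(args, train_loader, model, device, criterion, ...); since collate_fn_skel stacks out_distance and out_skeleton with np.stack, the torch.from_numpy conversions are an assumption about how those NumPy fields would reach the device, not code taken from the PR.

```python
import torch

# Field order follows collate_fn_skel:
#   pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton
for iteration, batch in enumerate(train_loader):
    if args.task == 22:
        _, volume, seg_mask, class_weight, _, label, out_skeleton = batch
        # label (the distance/skeleton target) and out_skeleton arrive as NumPy arrays.
        label = torch.from_numpy(label).to(device)      # assumption, not in the PR
        out_skeleton = torch.from_numpy(out_skeleton)   # assumption, not in the PR
    else:
        _, volume, label, class_weight, _ = batch
        label = label.to(device)
    volume = volume.to(device)
    class_weight = class_weight.to(device)
    output = model(volume)
```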