diff --git a/scripts/train.py b/scripts/train.py old mode 100755 new mode 100644 index 37ea8435..9a270d74 --- a/scripts/train.py +++ b/scripts/train.py @@ -13,8 +13,13 @@ def train(args, train_loader, model, device, criterion, record = AverageMeter() model.train() - for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader): + # for iteration, (_, volume, label, class_weight, _) in enumerate(train_loader): + for iteration, batch in enumerate(train_loader): + if args.task == 22: + _, volume, seg_mask, class_weight, _, label, out_skeleton = batch + else: + _, volume, label, class_weight, _ = batch volume, label = volume.to(device), label.to(device) class_weight = class_weight.to(device) output = model(volume) @@ -94,3 +99,4 @@ def main(): if __name__ == "__main__": main() + diff --git a/torch_connectomics/data/utils/__init__.py b/torch_connectomics/data/utils/__init__.py index 23cce11a..7c3e40cc 100755 --- a/torch_connectomics/data/utils/__init__.py +++ b/torch_connectomics/data/utils/__init__.py @@ -1 +1 @@ -from .functional_collate import collate_fn, collate_fn_test +from .functional_collate import collate_fn, collate_fn_test, collate_fn_skel diff --git a/torch_connectomics/data/utils/functional_collate.py b/torch_connectomics/data/utils/functional_collate.py index 7122708b..86e668c9 100755 --- a/torch_connectomics/data/utils/functional_collate.py +++ b/torch_connectomics/data/utils/functional_collate.py @@ -41,4 +41,20 @@ def collate_fn_plus(batch): for i in range(len(others)): extra[i] = torch.stack(others[i], 0) - return pos, out_input, out_label, weights, weight_factor, extra \ No newline at end of file + return pos, out_input, out_label, weights, weight_factor, extra + +def collate_fn_skel(batch): + """ + Puts each data field into a tensor with outer dimension batch size + :param batch: + :return: + """ + pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton = zip(*batch) + out_input = torch.stack(out_input, 0) + out_label = torch.stack(out_label, 0) + weights = torch.stack(weights, 0) + weight_factor = np.stack(weight_factor, 0) + out_distance = np.stack(out_distance, 0) + out_skeleton = np.stack(out_skeleton, 0) + + return pos, out_input, out_label, weights, weight_factor, out_distance, out_skeleton diff --git a/torch_connectomics/model/model_zoo/unetv0.py b/torch_connectomics/model/model_zoo/unetv0.py index a638f828..e4cebcb7 100755 --- a/torch_connectomics/model/model_zoo/unetv0.py +++ b/torch_connectomics/model/model_zoo/unetv0.py @@ -75,7 +75,7 @@ def __init__(self, in_channel=1, out_channel=3, filters=[32,64,128,256,256], act self.fconv = conv3d_bn_non(filters[0], out_channel, kernel_size=(3,3,3), padding=(1,1,1)) #final layer activation - if act='tanh': + if act == 'tanh': self.act = nn.Tanh() else: self.act = nn.Sigmoid() diff --git a/torch_connectomics/utils/net/arguments.py b/torch_connectomics/utils/net/arguments.py index 98c1f929..4fd750c5 100755 --- a/torch_connectomics/utils/net/arguments.py +++ b/torch_connectomics/utils/net/arguments.py @@ -7,8 +7,15 @@ def get_args(mode='train'): else: parser = argparse.ArgumentParser(description='Specify model inference arguments.') - # define tasks - # {0: neuron segmentationn, 1: synapse detection, 2: mitochondira segmentation} + """ + define tasks + {0: 'neuron segmentation', + 1: 'synapse detection', + 11: 'synapse polarity detection', + 2: 'mitochondria segmentation', + 22:'mitochondira segmentation with skeleton transform'} + """ + parser.add_argument('--task', type=int, default=0, help='specify the task') @@ -43,6 +50,9 @@ def get_args(mode='train'): parser.add_argument('-ln','--seg-name', default='seg-groundtruth2-malis.h5', help='Ground-truth label path') + parser.add_argument('-vm','--valid-mask', default=None, + help='Mask for the train images') + parser.add_argument('-ft','--finetune', type=bool, default=False, help='Fine-tune on previous model [Default: False]') diff --git a/torch_connectomics/utils/net/dataload.py b/torch_connectomics/utils/net/dataload.py old mode 100755 new mode 100644 index 23ddd2e4..51127330 --- a/torch_connectomics/utils/net/dataload.py +++ b/torch_connectomics/utils/net/dataload.py @@ -8,7 +8,7 @@ import torchvision.utils as vutils from torch_connectomics.data.dataset import AffinityDataset, SynapseDataset, SynapsePolarityDataset, MitoDataset, MitoSkeletonDataset -from torch_connectomics.data.utils import collate_fn, collate_fn_test +from torch_connectomics.data.utils import collate_fn, collate_fn_test, collate_fn_skel from torch_connectomics.data.augmentation import * TASK_MAP = {0: 'neuron segmentation', @@ -38,12 +38,19 @@ def get_input(args, model_io_size, mode='train'): if mode=='train': seg_name = args.seg_name.split('@') seg_name = [dir_name[0] + x for x in seg_name] + if args.valid_mask is not None: + mask_names = args.valid_mask.split('@') + mask_locations = [dir_name[0] + x for x in mask_names] # 1. load data model_input = [None]*len(img_name) if mode=='train': assert len(img_name)==len(seg_name) model_label = [None]*len(seg_name) + if args.valid_mask is not None: + assert len(img_name) == len(mask_locations) + model_mask = [None] * len(mask_locations) + for i in range(len(img_name)): model_input[i] = np.array(h5py.File(img_name[i], 'r')['main'])/255.0 @@ -61,7 +68,16 @@ def get_input(args, model_io_size, mode='train'): model_label[i] = np.pad(model_label[i], ((pad_size[0],pad_size[0]), (pad_size[1],pad_size[1]), (pad_size[2],pad_size[2])), 'reflect') + assert model_input[i].shape == model_label[i].shape + if args.valid_mask is not None: + model_mask[i] = np.array(h5py.File(mask_locations[i], 'r')['main']) + model_mask[i] = model_label[i].astype(np.float32) + print(f"mask shape: {model_mask[i].shape}") + model_label[i] = np.pad(model_label[i], ((pad_size[0],pad_size[0]), + (pad_size[1],pad_size[1]), + (pad_size[2],pad_size[2])), 'reflect') + assert model_input[i].shape == model_mask[i].shape if mode=='train': # setup augmentor @@ -102,11 +118,15 @@ def get_input(args, model_io_size, mode='train'): sample_label_size=sample_input_size, augmentor=augmentor, mode = 'train') if args.task == 22: # mitochondira segmentation with skeleton transform dataset = MitoSkeletonDataset(volume=model_input, label=model_label, sample_input_size=sample_input_size, - sample_label_size=sample_input_size, augmentor=augmentor, mode = 'train') + sample_label_size=sample_input_size, augmentor=augmentor, valid_mask=model_mask, mode='train') + img_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn_skel, + num_workers=args.num_cpu, pin_memory=True) + return img_loader img_loader = torch.utils.data.DataLoader( - dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn, - num_workers=args.num_cpu, pin_memory=True) + dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn, + num_workers=args.num_cpu, pin_memory=True) return img_loader else: @@ -118,8 +138,16 @@ def get_input(args, model_io_size, mode='train'): dataset = SynapseDataset(volume=model_input, label=None, sample_input_size=model_io_size, \ sample_label_size=None, sample_stride=model_io_size // 2, \ augmentor=None, mode='test') + elif args.task == 11: + dataset = SynapsePolarityDataset(volume=model_input, label=None, sample_input_size=model_io_size, + sample_label_size=None, sample_stride=model_io_size // 2, \ + augmentor=None, mode = 'test') elif args.task == 2: - dataset = MitoSkeletonDataset(volume=model_input, label=None, sample_input_size=model_io_size, \ + dataset = MitoDataset(volume=model_input, label=None, sample_input_size=model_io_size, \ + sample_label_size=None, sample_stride=model_io_size // 2, \ + augmentor=None, mode='test') + elif args.task == 22: + dataset = MitoSkeletonDataset(volume=model_input, label=None, sample_input_size=model_io_size, \ sample_label_size=None, sample_stride=model_io_size // 2, \ augmentor=None, mode='test') @@ -127,3 +155,4 @@ def get_input(args, model_io_size, mode='train'): dataset, batch_size=args.batch_size, shuffle=SHUFFLE, collate_fn = collate_fn_test, num_workers=args.num_cpu, pin_memory=True) return img_loader, volume_shape, pad_size +