Skip to content

Commit

Permalink
update networks
Browse files Browse the repository at this point in the history
  • Loading branch information
taigw committed Dec 6, 2019
1 parent eaeb496 commit 44a9e8c
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pymic/io/transform3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ def __call__(self, sample):
image= sample['image']
for chn in range(image.shape[0]):
mask = np.asarray(image[chn] > self.threshold[chn], image.dtype)
image[chn] = mask * image[chn]
image[chn] = mask * (image[chn] - self.threshold[chn])

sample['image'] = image
return sample
Expand Down
16 changes: 8 additions & 8 deletions pymic/layer/convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,21 +66,21 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,

assert(dim == 2 or dim == 3)
if(dim == 2):
self.conv = nn.Conv2d(in_channels, in_channels,
kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
self.conv1x1 = nn.Conv2d(in_channels, out_channels,
kernel_size = 1, stride = stride, padding = 0, dilation = dilation, groups = conv_group, bias = bias)
self.conv = nn.Conv2d(out_channels, out_channels,
kernel_size, stride, padding, dilation, groups = out_channels, bias = bias)
if(self.norm_type == 'batch_norm'):
self.bn = nn.modules.BatchNorm2d(out_channels)
elif(self.norm_type == 'group_norm'):
self.bn = nn.GroupNorm(self.norm_group, out_channels)
elif(self.norm_type is not None):
raise ValueError("unsupported normalization method {0:}".format(norm_type))
else:
self.conv = nn.Conv3d(in_channels, in_channels,
kernel_size, stride, padding, dilation, groups = in_channels, bias = bias)
else:
self.conv1x1 = nn.Conv3d(in_channels, out_channels,
kernel_size = 1, stride = 0, padding = 0, dilation = 0, groups = conv_group, bias = bias)
kernel_size = 1, stride = stride, padding = 0, dilation = dilation, groups = conv_group, bias = bias)
self.conv = nn.Conv3d(out_channels, out_channels,
kernel_size, stride, padding, dilation, groups = out_channels, bias = bias)
if(self.norm_type == 'batch_norm'):
self.bn = nn.modules.BatchNorm3d(out_channels)
elif(self.norm_type == 'group_norm'):
Expand All @@ -89,8 +89,8 @@ def __init__(self, in_channels, out_channels, kernel_size, dim = 3,
raise ValueError("unsupported normalization method {0:}".format(norm_type))

def forward(self, x):
f = self.conv(x)
f = self.conv1x1(f)
f = self.conv1x1(x)
f = self.conv(f)
if(self.norm_type is not None):
f = self.bn(f)
if(self.acti_func is not None):
Expand Down
50 changes: 49 additions & 1 deletion pymic/net3d/unet2d5.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,56 @@
from pymic.layer.activation import get_acti_func
from pymic.layer.convolution import ConvolutionLayer
from pymic.layer.deconvolution import DeconvolutionLayer
from pymic.net3d.unet2d5_ag import UNetBlock

class UNetBlock(nn.Module):
def __init__(self, in_channels, out_chnannels,
dim, resample, acti_func, acti_func_param):
super(UNetBlock, self).__init__()

self.in_chns = in_channels
self.out_chns = out_chnannels
self.dim = dim
self.resample = resample # resample should be 'down', 'up', or None
self.acti_func = acti_func

self.conv1 = ConvolutionLayer(in_channels, out_chnannels, kernel_size = 3, padding=1,
dim = self.dim, acti_func=get_acti_func(acti_func, acti_func_param))
self.conv2 = ConvolutionLayer(out_chnannels, out_chnannels, kernel_size = 3, padding=1,
dim = self.dim, acti_func=get_acti_func(acti_func, acti_func_param))
if(self.resample == 'down'):
if(self.dim == 2):
self.resample_layer = nn.MaxPool2d(kernel_size = 2, stride = 2)
else:
self.resample_layer = nn.MaxPool3d(kernel_size = 2, stride = 2)
elif(self.resample == 'up'):
self.resample_layer = DeconvolutionLayer(out_chnannels, out_chnannels, kernel_size = 2,
dim = self.dim, stride = 2, acti_func = get_acti_func(acti_func, acti_func_param))
else:
assert(self.resample == None)

def forward(self, x):
x_shape = list(x.shape)
if(self.dim == 2 and len(x_shape) == 5):
[N, C, D, H, W] = x_shape
new_shape = [N*D, C, H, W]
x = torch.transpose(x, 1, 2)
x = torch.reshape(x, new_shape)
output = self.conv1(x)
output = self.conv2(output)
resample = None
if(self.resample is not None):
resample = self.resample_layer(output)

if(self.dim == 2 and len(x_shape) == 5):
new_shape = [N, D] + list(output.shape)[1:]
output = torch.reshape(output, new_shape)
output = torch.transpose(output, 1, 2)
if(resample is not None):
resample_shape = list(resample.shape)
new_shape = [N, D] + resample_shape[1:]
resample = torch.reshape(resample, new_shape)
resample = torch.transpose(resample, 1, 2)
return output, resample

class UNet2D5(nn.Module):
def __init__(self, params):
Expand Down
31 changes: 20 additions & 11 deletions pymic/train_infer/train_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@
from pymic.util.parse_config import parse_config


class TrainInferAgent():
class TrainInferAgent(object):
def __init__(self, config, stage = 'train'):
self.config = config
self.stage = stage
assert(stage in ['train', 'inference', 'test'])

def __create_dataset(self):
def create_dataset(self):
root_dir = self.config['dataset']['root_dir']
train_csv = self.config['dataset'].get('train_csv', None)
valid_csv = self.config['dataset'].get('valid_csv', None)
Expand Down Expand Up @@ -74,11 +74,11 @@ def __create_dataset(self):
self.test_loder = torch.utils.data.DataLoader(test_dataset,
batch_size=batch_size, shuffle=False, num_workers=batch_size)

def __create_network(self):
def create_network(self):
self.net = get_network(self.config['network'])
self.net.double()

def __create_optimizer(self):
def create_optimizer(self):
self.optimizer = get_optimiser(self.config['training']['optimizer'],
self.net.parameters(),
self.config['training'])
Expand All @@ -91,7 +91,7 @@ def __create_optimizer(self):
self.config['training']['lr_gamma'],
last_epoch = last_iter)

def __train(self):
def train(self):
device = torch.device(self.config['training']['device_name'])
self.net.to(device)

Expand All @@ -111,7 +111,7 @@ def __train(self):
self.net.load_state_dict(self.checkpoint['model_state_dict'])
else:
self.checkpoint = None
self.__create_optimizer()
self.create_optimizer()

train_loss = 0
train_dice_list = []
Expand Down Expand Up @@ -218,7 +218,7 @@ def __train(self):
torch.save(save_dict, save_name)
summ_writer.close()

def __infer(self):
def infer(self):
device = torch.device(self.config['testing']['device_name'])
self.net.to(device)
# laod network parameters and set the network as evaluation mode
Expand Down Expand Up @@ -264,6 +264,15 @@ def test_time_dropout(m):
images = data['image'].double()
names = data['names']
print(names[0])
# for debug
# for i in range(images.shape[0]):
# image_i = images[i][0]
# label_i = images[i][0]
# image_name = "temp/{0:}_image.nii.gz".format(names[0])
# label_name = "temp/{0:}_label.nii.gz".format(names[0])
# save_nd_array_as_image(image_i, image_name, reference_name = None)
# save_nd_array_as_image(label_i, label_name, reference_name = None)
# continue
data['predict'] = volume_infer(images, self.net, device, class_num,
mini_batch_size, mini_patch_inshape, mini_patch_outshape, mini_patch_stride)

Expand Down Expand Up @@ -303,12 +312,12 @@ def test_time_dropout(m):
print("average testing time {0:}".format(avg_time))

def run(self):
agent.__create_dataset()
agent.__create_network()
self.create_dataset()
self.create_network()
if(self.stage == 'train'):
self.__train()
self.train()
else:
self.__infer()
self.infer()

if __name__ == "__main__":
if(len(sys.argv) < 3):
Expand Down
4 changes: 2 additions & 2 deletions pymic/util/rename_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ def rename_model_variable(input_file, output_file, input_var_list, output_var_li
torch.save(checkpoint, output_file)

if __name__ == "__main__":
input_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000.pt'
output_file = '/home/disk2t/projects/dlls/training_fetal_brain/model2/unet2dres/model_15000_rename.pt'
input_file = '/home/guotai/disk2t/projects/PyMIC/examples/prostate/model/unet3db/model_15000.pt'
output_file = '/home/guotai/disk2t/projects/PyMIC/examples/prostate/model/unet3db/model_15000_rename.pt'
input_var_list = ['conv.weight', 'conv.bias']
output_var_list= ['conv9.weight', 'conv9.bias']
rename_model_variable(input_file, output_file, input_var_list, output_var_list)

0 comments on commit 44a9e8c

Please sign in to comment.