In [None]:
import torch
from torch import nn
from MedicalNet.models import new_resnet as resnet
from torch.nn import Conv3d, BatchNorm3d, ReLU

`opt` must expose these attributes: `model`, `model_depth`, `n_seg_classes`, `input_W`, `input_H`, `input_D`, `resnet_shortcut`, `no_cuda`, `gpu_id`, `phase`, `pretrain_path`, `new_layer_names`.

In [None]:
class opt:
    """Configuration namespace for ``generate_model``; every field is read as ``opt.<field>``."""

    # --- architecture ---
    model = 'resnet'
    model_depth = 34
    n_seg_classes = 1  # a single segmentation class: neurons only
    # Forwarded to the resnet constructor as sample_input_W/H/D.
    # NOTE(review): all set to 1 here - confirm whether new_resnet actually uses them.
    input_W = input_H = input_D = 1
    # Forwarded as `shortcut_type`; presumably selects the residual shortcut variant
    # (e.g. projection conv vs. zero padding) - TODO confirm in new_resnet.
    resnet_shortcut = 'B'

    # --- device placement ---
    no_cuda = False
    gpu_id = []  # only referenced by disabled multi-GPU code in generate_model

    # --- training / weight loading ---
    phase = 'train'  # 'train' together with a non-empty pretrain_path triggers weight loading
    # NOTE(review): absolute, machine-specific path - consider making it configurable.
    pretrain_path = '/cvlabdata1/home/zakariya/SegmentingBrains/codes/MedicalNet/pretrain/resnet_34_23dataset.pth'
    # Substrings of parameter names that go into the 'new_parameters' group.
    new_layer_names = []

In [None]:
def generate_model(opt):
    """Build the 3-D ResNet described by ``opt`` and optionally load pretrained weights.

    Args:
        opt: configuration object read through attributes: ``model``,
            ``model_depth``, ``n_seg_classes``, ``input_W``, ``input_H``,
            ``input_D``, ``resnet_shortcut``, ``no_cuda``, ``phase``,
            ``pretrain_path`` and ``new_layer_names``.

    Returns:
        ``(model, parameters)``. When ``opt.phase == 'train'`` and
        ``opt.pretrain_path`` is non-empty, ``parameters`` is a dict with
        ``'base_parameters'`` and ``'new_parameters'`` lists (the latter holds
        every parameter whose name contains one of ``opt.new_layer_names``);
        otherwise it is the ``model.parameters()`` iterator.

    Raises:
        ValueError: if ``opt.model`` / ``opt.model_depth`` is unsupported, or
            if the checkpoint file name does not mention ``opt.model_depth``.
    """
    model = _build_resnet(opt)

    # Only move to the GPU when allowed, so the model can also be built on
    # CPU-only machines (opt.no_cuda = True).
    if not opt.no_cuda:
        model = model.cuda()

    if opt.phase == 'train' and opt.pretrain_path:
        _load_pretrained(model, opt.pretrain_path, opt.model_depth)
        return model, _split_parameters(model, opt.new_layer_names)

    return model, model.parameters()


def _build_resnet(opt):
    """Instantiate ``resnet.resnet<depth>`` with the constructor arguments taken from ``opt``."""
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if opt.model != 'resnet':
        raise ValueError("unsupported model {!r}; only 'resnet' is available".format(opt.model))
    if opt.model_depth not in (10, 18, 34, 50, 101, 152, 200):
        raise ValueError('unsupported ResNet depth {!r}'.format(opt.model_depth))

    # Every depth shares one constructor signature, so dispatch by name rather
    # than repeating the same call once per depth.
    constructor = getattr(resnet, 'resnet%d' % opt.model_depth)
    return constructor(
        sample_input_W=opt.input_W,
        sample_input_H=opt.input_H,
        sample_input_D=opt.input_D,
        shortcut_type=opt.resnet_shortcut,
        no_cuda=opt.no_cuda,
        num_seg_classes=opt.n_seg_classes)


def _load_pretrained(model, pretrain_path, model_depth):
    """Copy every checkpoint tensor whose (prefix-stripped) key exists in ``model`` into it, in place."""
    import os
    import re

    print('loading pretrained model {}'.format(pretrain_path))
    # Sanity check: the depth must appear in the checkpoint's file name as a
    # standalone number. A plain substring test on the whole path would accept,
    # e.g., depth 10 with 'resnet_101.pth'.
    file_name = os.path.basename(pretrain_path)
    if re.search(r'(?<!\d){}(?!\d)'.format(int(model_depth)), file_name) is None:
        raise ValueError('Loaded wrong model number: {!r} does not match depth {}'.format(
            pretrain_path, model_depth))

    # map_location='cpu' lets GPU-saved checkpoints load on any machine;
    # load_state_dict then copies the values onto the model's own device.
    checkpoint = torch.load(pretrain_path, map_location='cpu')

    net_dict = model.state_dict()
    # Checkpoint keys may carry a 'module.' prefix (presumably from an
    # nn.DataParallel wrapper at save time), which keeps them from matching the
    # bare model's keys; strip it and keep only keys the model actually has.
    pretrain_dict = {}
    for key, value in checkpoint['state_dict'].items():
        stripped = key.replace('module.', '')
        if stripped in net_dict:
            pretrain_dict[stripped] = value

    net_dict.update(pretrain_dict)
    model.load_state_dict(net_dict)


def _split_parameters(model, new_layer_names):
    """Split parameters into those whose name contains any of ``new_layer_names`` and the rest."""
    new_parameters = [
        param for name, param in model.named_parameters()
        if any(layer_name in name for layer_name in new_layer_names)
    ]
    new_ids = {id(param) for param in new_parameters}
    base_parameters = [param for param in model.parameters() if id(param) not in new_ids]
    return {'base_parameters': base_parameters,
            'new_parameters': new_parameters}


Note: in a Bottleneck block the Conv3d `downsample` is not applied in sequence with the other layers; its result is added to the bottleneck's output (the residual connection).

In [None]:
# Instantiate the configuration object read by generate_model.
pr = opt()

In [None]:
# Build the ResNet from `pr`. With phase 'train' and a non-empty pretrain_path,
# pretrained weights are loaded and `params` is a dict of base/new parameter groups;
# otherwise `params` is the model.parameters() iterator.
model, params = generate_model(pr)

In [None]:
# Fixed seed so the random probe volume (and everything derived from it) is reproducible.
torch.manual_seed(0)
# Fall back to the CPU when no GPU is available instead of failing outright.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Random probe volume of shape (1, 1, 64, 64, 64) used to trace feature-map sizes.
x = torch.rand((1, 1, 64, 64, 64)).to(device)

In [None]:
# Stand-alone stem: 7x7x7 conv (stride 1, padding 3) -> BatchNorm -> ReLU, then a
# stride-2 max-pool; presumably mirrors model.conv1/bn1/relu/maxpool - TODO confirm.
# NOTE(review): these layers are freshly (randomly) initialised, not the pretrained
# ones from `model`, so they are only suitable for shape probing.
seq1 = nn.Sequential(
    nn.Conv3d(1, 64, kernel_size=7, stride=1, padding=3, bias=False),
    nn.BatchNorm3d(64),
    nn.ReLU(inplace=True),
).to('cuda')
maxp = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1).to('cuda')

# Encoder stages are shared with `model` (same module objects, not copies);
# the max-pool is prepended to stage 1.
layer1 = nn.Sequential(maxp, model.layer1)
layer2, layer3, layer4 = model.layer2, model.layer3, model.layer4
conv_s = model.conv_seg

In [None]:
# Trace feature-map sizes through the encoder stages. Each stage runs once on the
# previous stage's output instead of re-running the whole chain for every print,
# and no_grad avoids building autograd graphs for this shape probe.
# NOTE: modules are still in training mode (unless .eval() was called), so
# BatchNorm running statistics are updated by this probe.
with torch.no_grad():
    feat = seq1(x)
    for stage in (layer1, layer2, layer3, layer4):
        feat = stage(feat)
        print(feat.size())

In [None]:
# The model's forward pass returns five tensors (presumably multi-scale feature
# maps - TODO confirm in new_resnet); unpack them for inspection.
x1,x2,x3,x4,x5 = model(x)

In [None]:
# Display the module tree of the generated model.
model

In [None]:
# Show the size of each of the five outputs returned by the model's forward pass.
for feature_map in (x1, x2, x3, x4, x5):
    print(feature_map.size())

## ResUNet network

Minimum input size: 64×64×64.

In [1]:
import torch
from torch import nn
from MedicalNet.models import new_resnet as resnet
from torch.nn import Conv3d, BatchNorm3d, ReLU
from topoloss4neurons.networksResUnet import ResUNet, UNet
from topoloss4neurons.networksResUnetV2 import ResUNetV2

In [2]:
# Fixed seed so the random probe volumes are reproducible.
torch.manual_seed(0)
# Fall back to the CPU when no GPU is available instead of failing outright.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Random probe input of shape (1, 1, 64, 64, 64).
x = torch.rand((1, 1, 64, 64, 64)).to(device)
# NOTE(review): y stays on the CPU and is not used by the visible cells - confirm it is needed.
y = torch.rand((1, 1, 64, 64, 64))

In [3]:
# Hyper-parameters shared by all three networks.
unet_kwargs = dict(three_dimensional=True, out_channels=1, n_convs=2, m_channels=32)

# ResUNet uses 4 levels: with m_channels=32 this pairs with the 4-stage ResNet-34
# encoder (an extra 2-layer block would be needed to go deeper).
network = ResUNet(n_levels=4, **unet_kwargs)
network2 = UNet(n_levels=5, **unet_kwargs)
networkV2 = ResUNetV2(n_levels=5, **unet_kwargs)


loading pretrained model /cvlabdata1/home/zakariya/SegmentingBrains/codes/MedicalNet/pretrain/resnet_34_23dataset.pth
loading pretrained model /cvlabdata1/home/zakariya/SegmentingBrains/codes/MedicalNet/pretrain/resnet_34_23dataset.pth


In [4]:
# Display the ResUNet module tree (its encoder `down_path` is a ResNet).
network

ResUNet(
  (down_path): ResNet(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 1, 1), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, e

In [5]:
# Forward the probe volume; the recorded output shows the 64x64x64 spatial size is preserved.
network(x).size()

torch.Size([1, 1, 64, 64, 64])

In [None]:
# Display the ResUNetV2 module tree.
networkV2

We do not use 5 levels for the first network (`ResUNet`): that would require 5 skip-connection concatenations, but the ResNet encoder only has 4 layers.