In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import torch.nn as nn
import torch.utils.data

In [4]:
class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, act_fn=nn.ReLU(inplace=True)):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            act_fn,
            nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                      stride=stride, padding=padding, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            act_fn
        )

    def forward(self, x):
        x = self.conv(x)
        return x

In [5]:
class bottle_neck_connection(nn.Module):
    def __init__(self, in_channels, out_channels, bottle_channels,
                 bias=True, act_fn=nn.ReLU(inplace=True)):
        super(bottle_neck_connection, self).__init__()
        self.bottleneck1 = nn.Sequential(
            nn.Conv3d(in_channels=in_channels, out_channels=bottle_channels, kernel_size=3,
                      stride=2, padding=1, bias=bias),
            nn.BatchNorm3d(num_features=bottle_channels),
            act_fn,
        )
        
        self.bottleneck2 = nn.Sequential(
            nn.Conv3d(in_channels=bottle_channels, out_channels=bottle_channels, kernel_size=1,
                      stride=1, padding=0, bias=False),
            act_fn,
        )
        
        self.bottleneck3 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=bottle_channels, out_channels=out_channels, kernel_size=3,
                               stride=2, padding=1, output_padding=1, bias=bias),
            nn.BatchNorm3d(num_features=out_channels),
            act_fn,
        )

    def forward(self, x):
        x = self.bottleneck1(x)
        x = self.bottleneck2(x)
        x = self.bottleneck3(x)
        return x

In [6]:
def get_total_params(model):
    total_params = sum(
    param.numel() for param in model.parameters()
    )
    return(total_params)

In [7]:
class NetBlock(nn.Module):
    def __init__(self, settings):
        super(NetBlock, self).__init__()
        self.backbone = settings["backbone"]
        self.in_blocks = torch.nn.ModuleDict(settings["in_blocks"])
        
    def forward(self, in_list):
        x = torch.cat(in_list, dim=1)
        x = self.backbone(x)
        return(x)

In [8]:
class Net(nn.Module):
    def __init__(self, net_blocks):
        super(Net, self).__init__()
        self.net_blocks = torch.nn.ModuleDict(net_blocks)
        self.net_graph = self.make_Net_graph()
        
        
    def make_Net_graph(self):
        graph = {}
        first_vertex = "IN"
        verified_verteсes = set( (first_vertex,))

        block_queue = [*self.net_blocks.keys()]
        queued_blocks_count = 0
        run = True

        while run:
            if len(block_queue)==0:
                break

            if queued_blocks_count>len(block_queue):
                print(graph)
                raise RuntimeError('Net::make_Net_graph::Error: Can\'t build graph, please check Net.net_blocks')

            block_name = block_queue.pop(0)
            in_blocks = self.net_blocks[block_name].in_blocks.keys()

            if sorted(list(in_blocks))== sorted(list(verified_verteсes.intersection(in_blocks))):
                verified_verteсes.add(block_name)
                for start in in_blocks:
                    if graph.get(start) is None:
                        graph.update({start : [block_name,]})
                    else: 
                        graph[start].append(block_name)
                queued_blocks_count = 0
            else:
                block_queue.append(block_name)
                queued_blocks_count+=1

        return graph
    

    def forward(self, IN):
        block_outs = {"IN" : IN} #we will iteratively calculate all outputs
        
        for vertes in self.net_graph.keys():
            for block_name in self.net_graph[vertes]:
                if (block_name not in block_outs.keys()):
                    block = self.net_blocks[block_name]

                    in_list = [] #upload all inputs to block
                    for in_block in block.in_blocks: 
                        resampled_out = block.in_blocks[in_block](block_outs[in_block])
                        in_list.append(resampled_out)
                    block_outs.update({block_name : (block(in_list))})
        return(block_outs["out"])
    
    
#     def forward_old(self, IN):
#         block_outs = {"IN" : IN} #we will iteratively calculate all outputs
        
#         for vertes in self.net_graph.keys():
#             queue = [*self.net_graph[vertes]] 
#             while len(queue) > 0:
#                 block_name = queue.pop(0)
#                 if (block_name not in block_outs.keys()):
#                     block = self.net_blocks[block_name]
                    
                    
#                     put_to_queue = False #If true, we can't calculate output now, set this block to the end of queue
#                     for in_block in block.in_blocks:
#                         if block_outs.get(in_block) is None:
#                             queue.append(block_name)
#                             put_to_queue = True
#                             break
#                     if put_to_queue:
#                         continue
                

#                     in_list = [] #upload all inputs to block
#                     for in_block in block.in_blocks: 
#                         resampled_out = block.in_blocks[in_block](block_outs[in_block])
#                         in_list.append(resampled_out)
#                     block_outs.update({block_name : (block(in_list))})
#         return(block_outs["out"])

In [9]:
channel_coef = 16
act_fn = nn.PReLU()

block_11_settings = {
    "in_blocks" : {
        "IN" : nn.Identity(),
        },    
    "backbone" : conv_block(1, channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_12_settings = {
    "in_blocks" : {
        "b21" : nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
        "b11" : nn.Identity(),
        },
    "backbone" : conv_block(3*channel_coef, 4*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_13_settings = {
    "in_blocks" : {
        "b22" : nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
        "b12" : bottle_neck_connection(4*channel_coef, 4*channel_coef, 8*channel_coef, act_fn=act_fn),
        },
    "backbone" : conv_block(12*channel_coef, 4*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_14_settings = {
    "in_blocks" : {
        "b23" : nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
        "b13" : nn.Identity(),
        },
    "backbone" : conv_block(12*channel_coef, 4*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_21_settings = {
    "in_blocks" : {
        "b11" : nn.MaxPool3d(2, 2),
        },
    "backbone" : conv_block(channel_coef, 2*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_22_settings = {
    "in_blocks" : {
        "b31" : nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
        "b21" : nn.Identity(),
        "b12" : nn.MaxPool3d(2, 2),
        },
    "backbone" : conv_block(10*channel_coef, 8*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_23_settings = {
    "in_blocks" : {
        "b32" : nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True),
        "b22" : nn.Identity(),
        "b13" : nn.MaxPool3d(2, 2),
        },
    "backbone" : conv_block(16*channel_coef, 8*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_31_settings = {
    "in_blocks" : {
        "b21" : nn.MaxPool3d(2, 2),
        },
    "backbone" : conv_block(2*channel_coef, 4*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_32_settings = {
    "in_blocks" : {
        "b31" : bottle_neck_connection(4*channel_coef, 4*channel_coef, 8*channel_coef, act_fn=act_fn),
        "b22" : nn.MaxPool3d(2, 2),
        },
    "backbone" : conv_block(12*channel_coef, 4*channel_coef, kernel_size=3, stride=1, padding=1, act_fn=act_fn), 
}

block_out_settings = {
     "in_blocks" : {
        "b14" : nn.Identity(),
        },
    "backbone" : conv_block(4*channel_coef, 1, kernel_size=3, stride=1, padding=1, act_fn=nn.Sigmoid()),
}

net_blocks = { 
    "b11" : NetBlock(block_11_settings),
    "b12" : NetBlock(block_12_settings),
    "b13" : NetBlock(block_13_settings),
    "b14" : NetBlock(block_14_settings),
    "b21" : NetBlock(block_21_settings),
    "b22" : NetBlock(block_22_settings),
    "b23" : NetBlock(block_23_settings),
    "b31" : NetBlock(block_31_settings),
    "b32" : NetBlock(block_32_settings),
    "out" : NetBlock(block_out_settings),
}

In [10]:
model = Net(net_blocks)
print("total_params:", get_total_params(model))
model.net_graph

total_params: 4981938


{'IN': ['b11'],
 'b11': ['b21', 'b12'],
 'b21': ['b31', 'b12', 'b22'],
 'b31': ['b22', 'b32'],
 'b12': ['b22', 'b13'],
 'b22': ['b32', 'b13', 'b23'],
 'b32': ['b23'],
 'b13': ['b23', 'b14'],
 'b23': ['b14'],
 'b14': ['out']}

In [11]:
o = model(torch.rand(1, 1, 64, 64, 64))
print(o.shape)

torch.Size([1, 1, 64, 64, 64])


In [12]:
model = model.to('cuda')
1

1

In [13]:
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)
n_epochs = 100

In [16]:
x = torch.rand(4, 1, 64, 64, 64)
GT = torch.rand(4, 1, 64, 64, 64)

In [17]:
model = model.to('cuda')
for epoch in range(n_epochs):
    x_cuda = x.to('cuda')
    GT_cuda = GT.to('cuda')
    
    out = model.forward(x_cuda)   
    loss = loss_fn(GT_cuda, out)

    opt.zero_grad()
    loss.backward()
    opt.step()
    if epoch%10==0:
        print(loss.item())

0.5160942673683167
0.10768172144889832
0.09517958015203476
0.09066644310951233
0.08971656858921051
0.08916527777910233
0.0886305645108223
0.08824214339256287
0.08768998831510544
0.08686032146215439
