In [89]:
import torch
import torch.nn as nn
import torch.utils.data
from copy import deepcopy

In [90]:
def get_total_params(model):
    """Count every scalar element held in the parameters of ``model``.

    Args:
        model: object exposing ``parameters()`` that yields tensors
            (e.g. an ``nn.Module``).

    Returns:
        int: sum of ``numel()`` over all yielded parameters; 0 when the
        model has no parameters.
    """
    count = 0
    for tensor in model.parameters():
        count += tensor.numel()
    return count

In [91]:
# Directed graph of the network: each key is a vertex and its value is the
# tuple of vertices that consume its output. "IN" stands for the network
# input and "OUT" is the terminal marker. Key order is kept as written
# because the forward pass walks the mapping in insertion order.
net_graph = dict(
    IN=("b11",),
    b11=("b12", "b21"),
    b21=("b12", "b31"),
    b12=("b22", "b13"),
    b31=("b22",),
    b22=("b13",),
    b13=("out",),
    out=("OUT",),
)

In [92]:
class NetBlock(nn.Module):
    """One vertex of the network: concatenate the inputs, then run a backbone.

    Args:
        settings: mapping with two entries:
            "backbone"  - module applied to the concatenated inputs;
            "in_blocks" - mapping from the name of a source block to the
                module applied to that source's output before concatenation.

    Attributes:
        backbone: the module taken from ``settings["backbone"]``.
        in_blocks: ``nn.ModuleDict`` holding the per-source modules in the
            same order as ``settings["in_blocks"]``.
    """

    def __init__(self, settings):
        super().__init__()
        self.backbone = settings["backbone"]
        # Register the per-source modules as real submodules so that
        # .to(), .parameters(), state_dict() and train()/eval() reach them;
        # a plain dict attribute is invisible to nn.Module.
        # Entries are inserted one by one to keep the caller's key order,
        # which fixes the channel order of the concatenation.
        self.in_blocks = nn.ModuleDict()
        for source_name, module in settings["in_blocks"].items():
            self.in_blocks[source_name] = module

    def forward(self, in_list):
        """Concatenate ``in_list`` along dim 1 and apply the backbone.

        Args:
            in_list: sequence of tensors that agree on every dimension
                except dim 1.

        Returns:
            The backbone output for the concatenated tensor.
        """
        merged = torch.cat(in_list, dim=1)
        return self.backbone(merged)

In [93]:
class NetConnection(nn.Module):
    """Edge of the network: remembers its source and transforms its output.

    Args:
        connect_from: identifier of the block this connection reads from;
            stored unchanged on ``self.connect_from``.
        backbone: module applied to the incoming tensor. When omitted
            (``None``) an ``nn.Identity`` is used, i.e. the tensor is
            passed through untouched.
    """

    def __init__(self, connect_from, backbone=None):
        super().__init__()
        self.connect_from = connect_from
        self.backbone = nn.Identity() if backbone is None else backbone

    def forward(self, x):
        """Return ``x`` after it has been passed through the backbone."""
        return self.backbone(x)

In [94]:
# Per-block settings. Every "in_blocks" entry maps the name of a source block
# to the module applied to that source's output (identity, 2x max-pool or
# 2x trilinear upsampling); "backbone" is the convolution applied after the
# inputs are concatenated. Each backbone's input channel count matches the
# sum of the channels produced by its sources (3 per block, 1 for "IN").

def _conv3x3x3(in_channels, out_channels):
    """Build a 3x3x3 convolution with stride 1 and padding 1."""
    return nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)


def _pool2x():
    """Build a max-pool with kernel size 2 and stride 2."""
    return nn.MaxPool3d(2, 2)


def _upsample2x():
    """Build a trilinear upsampler with scale factor 2."""
    return nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)


block_11_settings = dict(
    in_blocks={"IN": nn.Identity()},
    backbone=_conv3x3x3(1, 3),
)

block_21_settings = dict(
    in_blocks={"b11": _pool2x()},
    backbone=_conv3x3x3(3, 3),
)

block_12_settings = dict(
    in_blocks={
        "b21": _upsample2x(),
        "b11": nn.Identity(),
    },
    backbone=_conv3x3x3(6, 3),
)

block_31_settings = dict(
    in_blocks={"b21": _pool2x()},
    backbone=_conv3x3x3(3, 3),
)

block_22_settings = dict(
    in_blocks={
        "b31": _upsample2x(),
        "b21": nn.Identity(),
        "b12": _pool2x(),
    },
    backbone=_conv3x3x3(9, 3),
)

block_13_settings = dict(
    in_blocks={
        "b22": _upsample2x(),
        "b12": nn.Identity(),
    },
    backbone=_conv3x3x3(6, 3),
)

block_out_settings = dict(
    in_blocks={"b13": nn.Identity()},
    backbone=_conv3x3x3(3, 1),
)

# Wrap each settings dict in a NetBlock, keyed by the block name used in
# net_graph. Construction order follows the listing below.
net_blocks = {
    block_name: NetBlock(block_settings)
    for block_name, block_settings in (
        ("b11", block_11_settings),
        ("b21", block_21_settings),
        ("b31", block_31_settings),
        ("b12", block_12_settings),
        ("b22", block_22_settings),
        ("b13", block_13_settings),
        ("out", block_out_settings),
    )
}

In [95]:
class Net_(nn.Module):
    """Single-pass graph network (earlier variant of ``Net``).

    The graph is walked once, in the insertion order of ``net_graph``. A
    block is evaluated the first time it is encountered, so every one of its
    sources must already have been evaluated at that point; otherwise a
    ``KeyError`` is raised.

    Args:
        net_blocks: mapping from block name to block module. Each block must
            expose ``in_blocks`` (mapping source name -> module applied to
            that source's output) and accept a list of tensors when called.
        net_graph: mapping from vertex name to an iterable of successor
            names. The name "OUT" is a terminal marker and is skipped.
    """

    def __init__(self, net_blocks, net_graph):
        # Zero-argument super(): the original named a different class here,
        # which made this class impossible to instantiate.
        super().__init__()
        self.net_blocks = torch.nn.ModuleDict(net_blocks)
        self.net_graph = net_graph

    def forward(self, IN):
        """Evaluate the graph on ``IN`` and return the output of block "out".

        Raises:
            KeyError: a block is reached before one of its sources has been
                evaluated, a successor name is not in ``net_blocks``, or no
                block named "out" was evaluated.
        """
        block_outs = {"IN": IN}

        for vertex in self.net_graph:
            for block_name in self.net_graph[vertex]:
                # "OUT" only marks the end of the graph; evaluated blocks
                # are never recomputed.
                if block_name == "OUT" or block_name in block_outs:
                    continue

                block = self.net_blocks[block_name]
                missing = [src for src in block.in_blocks if src not in block_outs]
                if missing:
                    raise KeyError(
                        "block {!r} reached before its inputs {} were computed".format(
                            block_name, missing
                        )
                    )

                # Apply the per-source module to every source output, in the
                # order in which the block lists its sources.
                in_list = [block.in_blocks[src](block_outs[src]) for src in block.in_blocks]
                block_outs[block_name] = block(in_list)
        return block_outs["out"]

In [96]:
class Net(nn.Module):
    """Graph network whose blocks are evaluated as soon as their inputs exist.

    Args:
        net_blocks: mapping from block name to block module. Each block must
            expose ``in_blocks`` (mapping source name -> module applied to
            that source's output) and accept a list of tensors when called.
        net_graph: mapping from vertex name to an iterable of successor
            names. The name "OUT" is a terminal marker and is skipped.
    """

    def __init__(self, net_blocks, net_graph):
        super().__init__()
        self.net_blocks = torch.nn.ModuleDict(net_blocks)
        self.net_graph = net_graph

    def forward(self, IN):
        """Evaluate the graph on ``IN`` and return the output of block "out".

        Blocks are taken in the order in which they first appear as
        successors in ``net_graph``. A block whose sources are not all
        available yet is deferred to the next pass, so the listing order of
        the graph does not have to be a valid evaluation order.

        Raises:
            RuntimeError: a full pass evaluates no block, i.e. the remaining
                blocks wait on sources that can never be produced (cycle or
                unknown source name).
            KeyError: a successor name is not in ``net_blocks``, or no block
                named "out" was evaluated.
        """
        block_outs = {"IN": IN}  # outputs computed so far, keyed by block name

        # Unique successor names in first-seen order; "OUT" is only a marker.
        pending = list(dict.fromkeys(
            name
            for successors in self.net_graph.values()
            for name in successors
            if name != "OUT"
        ))

        while pending:
            deferred = []
            for block_name in pending:
                block = self.net_blocks[block_name]

                # Postpone the block while any of its sources is missing.
                if any(src not in block_outs for src in block.in_blocks):
                    deferred.append(block_name)
                    continue

                # Apply the per-source module to every source output, in the
                # order in which the block lists its sources.
                in_list = [block.in_blocks[src](block_outs[src]) for src in block.in_blocks]
                block_outs[block_name] = block(in_list)

            # No block was evaluated in this pass: the rest can never run.
            if len(deferred) == len(pending):
                raise RuntimeError(
                    "cannot evaluate blocks {}: their inputs are never produced".format(deferred)
                )
            pending = deferred

        return block_outs["out"]

In [97]:
# Random input volume and random target of the same shape, both drawn
# uniformly from [0, 1). The input is drawn first, then the target.
volume_shape = (1, 1, 64, 64, 64)
x = torch.rand(volume_shape)
GT = torch.rand(volume_shape)

In [98]:
# Assemble the network from the block table and the graph, then report its size.
model = Net(net_blocks=net_blocks, net_graph=net_graph)
print("total_params:", get_total_params(model))

total_params: 2368


In [99]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
# Trailing literal kept from the original cell: it becomes the cell's
# displayed value instead of the module returned by .to().
1

1

In [100]:
# Training configuration: number of optimisation steps, mean-squared-error
# loss, and Adam over all model parameters with its default hyperparameters.
n_epochs = 100
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters())

In [101]:
# Fit the model to the single (x, GT) pair for n_epochs steps, printing the
# loss after every step.

# Follow whatever device the model currently lives on instead of assuming
# CUDA. The *_cuda names are kept so later cells that refer to them keep
# working, even when the tensors end up on the CPU.
train_device = next(model.parameters()).device

# The data never changes, so move it to the device once, not every epoch.
x_cuda = x.to(train_device)
GT_cuda = GT.to(train_device)

for epoch in range(n_epochs):
    # Call the module itself (not .forward) so registered hooks are run.
    out = model(x_cuda)
    # Conventional (prediction, target) order; MSE is symmetric, so the
    # value and the gradients are the same as with the arguments swapped.
    loss = loss_fn(out, GT_cuda)

    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

0.38896670937538147
0.34376659989356995
0.30145424604415894
0.25598061084747314
0.20562002062797546
0.15224692225456238
0.10749007761478424
0.10021573305130005
0.14329808950424194
0.1486728936433792
0.12211531400680542
0.10039184242486954
0.09548762440681458
0.10177996009588242
0.11050626635551453
0.11646005511283875
0.11773843318223953
0.1146751195192337
0.10863392800092697
0.10166993737220764
0.09611101448535919
0.09396779537200928
0.09587757289409637
0.10005359351634979
0.10296913981437683
0.10237058252096176
0.098988838493824
0.09537513554096222
0.0934314951300621
0.0935005471110344
0.09477837383747101
0.09618861228227615
0.09695689380168915
0.09673995524644852
0.09571388363838196
0.09429434686899185
0.09308551251888275
0.09254296123981476
0.092784583568573
0.09348484873771667
0.0940474346280098
0.09403079748153687
0.09344612061977386
0.09266987442970276
0.09214068949222565
0.09202621877193451
0.09223079681396484
0.09250684827566147
0.09262386709451675
0.09249089658260345
0.0921664