In [48]:
import torch
import torch.nn as nn
import torch.utils.data
from copy import deepcopy

In [49]:
def get_total_params(model):
    total_params = sum(
    param.numel() for param in model.parameters()
    )
    return(total_params)

In [50]:
class NetConnection(nn.Module):
    def __init__(self, connect_from, backbone=None):
        super(NetConnection, self).__init__()
        self.connect_from = connect_from
        if backbone is None: 
            self.backbone = nn.Identity()
        else:    
            self.backbone = backbone
        
    def forward(self, x):
        x = self.backbone(x)
        return(x)

In [92]:
net_graph = {
    "IN" :  (NetConnection("b11"),),
    "b11" : (NetConnection("b21", nn.MaxPool3d(2, 2)),
             NetConnection("b12")),
    "b21" : (NetConnection("b12", nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)),),
    "b12" : (NetConnection("out"),),
    "out" : (NetConnection("OUT"),),
}

In [93]:
class NetBlock(nn.Module):
    def __init__(self, settings):
        super(NetBlock, self).__init__()
        self.in_blocks = settings["in"]
        self.backbone = settings["backbone"]
        
    def forward(self, in_list):
        x = torch.cat(in_list, dim=1)
        x = self.backbone(x)
        return(x)

In [94]:
block_11_settings = {
    "in" : ("IN",),
    "level" : 1,
    "backbone" : nn.Conv3d(1, 3, kernel_size=3, stride=1, padding=1), #nn.Identity(),
}

block_21_settings = {
    "in" : ("b11",),
    "level" : 2,
    "backbone" : nn.Conv3d(3, 3, kernel_size=3, stride=1, padding=1), #nn.Identity(),
}

block_12_settings = {
    "in" : ("b11", "b21"),
    "level" : 1,
    "backbone" : nn.Conv3d(6, 1, kernel_size=3, stride=1, padding=1), #nn.Identity(),
}

block_out_settings = {
    "in" : ("b12",),
    "level" : 1,
    "backbone" : nn.Conv3d(1, 1, kernel_size=3, stride=1, padding=1), #nn.Identity(),
}

net_blocks = { 
    "b11" : NetBlock(block_11_settings),
    "b21" : NetBlock(block_21_settings),
    "b12" : NetBlock(block_12_settings),
    "out" : NetBlock(block_out_settings),
}

#net_blocks

In [95]:
class Net(nn.Module):
    def __init__(self, net_blocks, net_graph):
        super(Net, self).__init__()
        self.net_blocks = torch.nn.ModuleDict(net_blocks)
        self.net_graph = (net_graph)
        
    def forward(self, IN):
        block_outs = {"IN" : IN}
        
        for vertes in self.net_graph.keys():
            for connection in self.net_graph[vertes]:
                block_name = connection.connect_to
                if (block_name == "OUT"):
                    pass 
                
                elif (block_name not in block_outs.keys()):
                    block = self.net_blocks[block_name]
                    print(block_name)
                    
                    in_list = []
                    
                    #if block_outs.get(out) = None: continue 
                    for input_block in block.in_blocks:
                        resampled_out = connection.backbone(block_outs[input_block])
                        print(resampled_out.shape)
                        in_list.append(resampled_out)
                    block_outs.update({block_name : (block(in_list))})
        #print("\n block outputs:", block_outs)
        return(block_outs["out"])

In [96]:
x = torch.rand(1, 1, 64, 64, 64)
GT = torch.rand(1, 1, 64, 64, 64) 

In [97]:
model = Net(net_blocks, net_graph)
print("total_params:", get_total_params(model))

total_params: 521


In [98]:
model.to('cuda')
1

1

In [99]:
loss_fn = nn.MSELoss()
opt = torch.optim.Adam(model.parameters())
n_epochs = 100

In [100]:
t = torch.rand(1, 1, 32, 32, 32)
print(t.shape)
up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)
print(up(t).shape)

torch.Size([1, 1, 32, 32, 32])
torch.Size([1, 1, 64, 64, 64])


In [102]:
for epoch in range(n_epochs):
    x_cuda = x.to('cuda')
    GT_cuda = GT.to('cuda')
    
    out = model.forward(x_cuda)   
    loss = loss_fn(GT_cuda, out)

    opt.zero_grad()
    loss.backward()
    opt.step()
    print(loss.item())

b11
torch.Size([1, 1, 64, 64, 64])
b21
torch.Size([1, 3, 32, 32, 32])
b12
torch.Size([1, 3, 64, 64, 64])
torch.Size([1, 3, 32, 32, 32])


RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 64 but got size 32 for tensor number 1 in the list.