In [1]:
from nnunet_ext.network_architecture.nca.NCA2D import NCA2D
import torch
import torch.nn.functional as F
# Disable autograd for the whole notebook: every cell below only runs forward
# passes and compares outputs, so no gradients are needed.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7fe7bd9c0bb0>

In [2]:
# Seed the global RNG so the random input (and the randomly initialised NCA
# weights created in later cells) are reproducible under Restart & Run All.
torch.manual_seed(0)

# Prefer the GPU but fall back to CPU so the notebook also runs on machines
# without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single-sample, single-channel 4x4 random test image.
x = torch.randn(1, 1, 4, 4, device=device)

In [3]:
# Base NCA: 32 state channels, the first of which is the image input channel.
hidden_size0 = 64
nca0 = NCA2D(
    num_channels=32,
    num_input_channels=1,
    num_classes=0,
    hidden_size=hidden_size0,
    fire_rate=1,
    num_steps=10,
    use_norm=False,
).to(device)

# Extension NCA: 32 additional state channels and no input channels of its own.
hidden_size1 = 64
nca1 = NCA2D(
    num_channels=32,
    num_input_channels=0,
    num_classes=0,
    hidden_size=hidden_size1,
    fire_rate=1,
    num_steps=10,
    use_norm=False,
)
# Swap in a wider first 1x1 layer: the extension sees the states of BOTH NCAs
# plus both conv outputs, i.e. 2 * (base channels + extension channels) inputs.
nca1.fc0 = torch.nn.Conv2d(
    in_channels=2 * (nca0.num_channels + nca1.num_channels),
    out_channels=hidden_size1,
    kernel_size=1,
)
nca1 = nca1.to(device)

In [4]:
# Reference rollout of the base NCA on its own; `out0` is the tensor the merged
# rollout is compared against in the cells below.
# NOTE(review): make_state / forward_internal live in NCA2D (not shown here) --
# presumably they build the initial state from x and run the update steps;
# confirm against the class.
state = nca0.make_state(x)
out0 = nca0.forward_internal(state)

In [5]:
def extended_nca_update(base_nca, new_nca, states):
    """Run one update step of a base NCA and an extension NCA with separate states.

    The base NCA is updated from its own state and conv output only. The
    extension NCA's first 1x1 layer receives both states and both conv outputs
    concatenated along the channel axis, so it can read the base NCA's state
    while the base NCA never reads the extension's.

    Args:
        base_nca: Object exposing ``conv``, ``fc0``, ``batch_norm``, ``fc1`` and
            ``num_input_channels``.
        new_nca: Object exposing ``conv``, ``fc0``, ``batch_norm`` and ``fc1``.
        states: Indexable pair ``(base_state, new_state)``. The first
            ``base_nca.num_input_channels`` channels of ``base_state`` are
            passed through unchanged.

    Returns:
        A two-element list ``[next_base_state, next_new_state]``.
    """
    base_state = states[0]
    new_state = states[1]
    n_in = base_nca.num_input_channels

    # Per-network conv outputs.
    base_perception = base_nca.conv(base_state)
    new_perception = new_nca.conv(new_state)

    # Channel-wise inputs of the first 1x1 layers.
    base_features = torch.cat([base_state, base_perception], dim=1)
    new_features = torch.cat(
        [base_state, new_state, base_perception, new_perception], dim=1
    )

    # fc0 -> norm -> ReLU -> fc1 for each network.
    base_update = base_nca.fc1(
        F.relu(base_nca.batch_norm(base_nca.fc0(base_features)), inplace=False)
    )
    new_update = new_nca.fc1(
        F.relu(new_nca.batch_norm(new_nca.fc0(new_features)), inplace=False)
    )

    # Residual update; the base NCA's input channels are copied over untouched.
    next_base_state = torch.cat(
        [base_state[:, :n_in], base_state[:, n_in:] + base_update], dim=1
    )
    return [next_base_state, new_state + new_update]

In [8]:
def _build_merged_nca(base_nca, new_nca):
    """Build one NCA2D whose weights are assembled from base_nca and new_nca.

    Layout of the merged weights:
      * conv: both conv weights/biases stacked along the output-channel axis.
      * fc0:  the base rows get zero columns wherever the extension's state or
              conv output would enter, so the base hidden units ignore the
              extension; the extension rows are taken as-is.
      * fc1:  block-diagonal, so each network's update is produced only from
              its own hidden units.

    Hidden sizes and the target device/dtype are read from the weights of the
    given networks rather than from notebook globals.
    """
    assert base_nca.fc1.bias is None and new_nca.fc1.bias is None

    base_fc0 = base_nca.fc0.weight
    new_fc0 = new_nca.fc0.weight
    base_fc1 = base_nca.fc1.weight
    new_fc1 = new_nca.fc1.weight
    hidden0 = base_fc0.shape[0]
    hidden1 = new_fc0.shape[0]
    base_channels = base_nca.num_channels

    merged_nca = NCA2D(num_channels=base_channels + new_nca.num_channels,
                       num_input_channels=base_nca.num_input_channels,
                       num_classes=16, hidden_size=hidden0 + hidden1,
                       fire_rate=1, num_steps=10, use_norm=False).to(base_fc0.device)

    # conv: stack the two convs along the output-channel axis.
    merged_nca.conv.weight = torch.nn.Parameter(
        torch.cat([base_nca.conv.weight, new_nca.conv.weight], dim=0))
    merged_nca.conv.bias = torch.nn.Parameter(
        torch.cat([base_nca.conv.bias, new_nca.conv.bias], dim=0))

    # fc0: the merged input is [base state, new state, base conv, new conv];
    # pad the base rows with zeros at the two "new" positions.
    zero_cols = base_fc0.new_zeros(hidden0, new_nca.num_channels, 1, 1)
    base_fc0_padded = torch.cat([base_fc0[:, :base_channels], zero_cols,
                                 base_fc0[:, base_channels:], zero_cols], dim=1)
    fc0 = torch.cat([base_fc0_padded, new_fc0], dim=0)
    assert merged_nca.fc0.weight.shape == fc0.shape, f"{merged_nca.fc0.weight.shape}, {fc0.shape}"
    merged_nca.fc0.weight = torch.nn.Parameter(fc0)
    merged_nca.fc0.bias = torch.nn.Parameter(
        torch.cat([base_nca.fc0.bias, new_nca.fc0.bias], dim=0))

    # fc1: block-diagonal over [base hidden units, new hidden units].
    base_fc1_padded = torch.cat(
        [base_fc1, base_fc1.new_zeros(base_fc1.shape[0], hidden1, 1, 1)], dim=1)
    new_fc1_padded = torch.cat(
        [new_fc1.new_zeros(new_fc1.shape[0], hidden0, 1, 1), new_fc1], dim=1)
    fc1 = torch.cat([base_fc1_padded, new_fc1_padded], dim=0)
    assert merged_nca.fc1.weight.shape == fc1.shape, f"{merged_nca.fc1.weight.shape}, {fc1.shape}"
    merged_nca.fc1.weight = torch.nn.Parameter(fc1)
    assert merged_nca.fc1.bias is None

    return merged_nca


def extended_nca_updated_merged(base_nca, new_nca, states, merged_nca=None):
    """Run one update step of base_nca + new_nca as a single merged NCA.

    Args:
        base_nca: Base NCA2D; its ``batch_norm`` must be ``torch.nn.Identity``
            and its ``fc1`` must have no bias.
        new_nca: Extension NCA2D whose ``fc0`` takes the states and conv outputs
            of both networks; same ``batch_norm`` / ``fc1`` requirements.
        states: Single tensor with the base state (input channels first)
            followed by the extension state along dim 1.
        merged_nca: Optional network previously returned by
            ``_build_merged_nca(base_nca, new_nca)``. Pass it when stepping in a
            loop to avoid rebuilding the merged weights on every call. When
            omitted, the merged network is built on each call.

    Returns:
        The updated state WITHOUT the base NCA's input channels; the caller is
        expected to prepend the input again before the next step.
    """
    # The merged step applies no normalisation, so it only matches the two
    # separate networks when their norm layers are no-ops.
    assert isinstance(base_nca.batch_norm, torch.nn.Identity)
    assert isinstance(new_nca.batch_norm, torch.nn.Identity)

    if merged_nca is None:
        merged_nca = _build_merged_nca(base_nca, new_nca)

    perception = merged_nca.conv(states)
    hidden = merged_nca.fc0(torch.cat([states, perception], dim=1))
    hidden = F.relu(hidden, inplace=False)
    update = merged_nca.fc1(hidden)

    # Residual update on everything except the input channels.
    return update + states[:, base_nca.num_input_channels:]

In [9]:
# Zero-initialised hidden channels of the base NCA (everything except the input).
state1 = torch.zeros(
    (x.shape[0], nca0.num_channels - nca0.num_input_channels, x.shape[2], x.shape[3]),
    device=x.device,
)
# Zero-initialised state of the extension NCA.
state2 = torch.zeros(
    (x.shape[0], nca1.num_channels, x.shape[2], x.shape[3]),
    device=x.device,
)

# Merged state layout along dim 1: [input, base hidden channels, extension channels].
states = torch.cat([x, state1, state2], dim=1)

# Each merged step returns the state without the input channels, so the input
# is prepended again before the next step.
for _ in range(nca0.num_steps):
    states = torch.cat([x, extended_nca_updated_merged(nca0, nca1, states)], dim=1)

In [10]:
# The leading nca0.num_channels channels of the merged state belong to the base
# NCA; check them against the stand-alone rollout, both bit-exactly and within
# floating-point tolerance.
(out0 == states[:, :nca0.num_channels]).all(), torch.allclose(out0, states[:, :nca0.num_channels])

(tensor(True, device='cuda:0'), True)

In [11]:
# Display the stand-alone base NCA output for visual inspection.
out0

tensor([[[[-6.1646e-01, -3.8514e-03,  3.1803e-01, -4.4318e-01],
          [-2.2502e-01,  3.1317e-01,  6.5186e-01, -5.3522e-01],
          [ 8.4361e-01, -1.7829e+00,  1.3383e+00,  6.7325e-01],
          [-9.0178e-01, -1.2204e+00,  2.0156e+00, -5.5390e-01]],

         [[ 1.5078e+00,  1.5343e+00,  1.2887e+00,  1.5620e+00],
          [ 1.5130e+00,  1.6804e+00,  2.0005e+00,  1.7453e+00],
          [ 1.7658e+00,  2.0133e+00,  2.0679e+00,  2.1928e+00],
          [ 1.6515e+00,  1.8816e+00,  2.1945e+00,  1.7193e+00]],

         [[ 3.8411e-02, -6.2814e-02, -2.1936e-01, -5.7471e-03],
          [ 1.0628e-01, -2.7259e-01, -2.4212e-01,  1.4507e-01],
          [-5.1207e-01,  1.1675e-02, -7.4097e-01, -7.8765e-02],
          [ 2.5764e-01, -2.0948e-01, -1.0839e+00,  2.9660e-01]],

         [[ 8.7707e-03,  2.4454e-01,  9.6802e-02,  1.6522e-01],
          [ 1.1759e-01, -1.3448e-02, -1.2778e-01,  3.5558e-01],
          [-7.0183e-02,  2.5259e-01, -1.6833e-01,  7.5967e-02],
          [ 3.4549e-01,  1.2975e-0

In [12]:
# The merged state carries the channels of both NCAs; out0 only those of the base NCA.
states.shape, out0.shape

(torch.Size([1, 64, 4, 4]), torch.Size([1, 32, 4, 4]))

In [13]:
# Largest absolute deviation between the base part of the merged state and the
# stand-alone base NCA rollout.
(states[:, :nca0.num_channels] - out0).abs().max()

tensor(0., device='cuda:0')

In [14]:
# Display the full merged state (base NCA channels followed by the extension's).
states

tensor([[[[-0.6165, -0.0039,  0.3180, -0.4432],
          [-0.2250,  0.3132,  0.6519, -0.5352],
          [ 0.8436, -1.7829,  1.3383,  0.6733],
          [-0.9018, -1.2204,  2.0156, -0.5539]],

         [[ 1.5078,  1.5343,  1.2887,  1.5620],
          [ 1.5130,  1.6804,  2.0005,  1.7453],
          [ 1.7658,  2.0133,  2.0679,  2.1928],
          [ 1.6515,  1.8816,  2.1945,  1.7193]],

         [[ 0.0384, -0.0628, -0.2194, -0.0057],
          [ 0.1063, -0.2726, -0.2421,  0.1451],
          [-0.5121,  0.0117, -0.7410, -0.0788],
          [ 0.2576, -0.2095, -1.0839,  0.2966]],

         ...,

         [[ 0.6761,  0.7561,  0.6364,  0.7024],
          [ 0.7121,  0.6493,  0.9677,  0.7729],
          [ 0.4524,  1.0342,  0.9438,  0.9540],
          [ 0.6062,  0.8332,  0.9915,  0.9807]],

         [[ 0.1580,  0.3982,  0.6170,  0.0700],
          [ 0.3999,  0.6019,  0.8684, -0.0236],
          [ 1.0192, -0.1615,  1.4814,  0.6664],
          [ 0.5339, -0.1314,  1.7835, -0.0698]],

         [[ 0.3