PyTorch supports modifying the gradient of a tensor by registering a hook. This may be the simplest approach
to integrating the leaf classification network.
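As a minimal sketch of the hook mechanism, the example below registers a hook on a toy tensor with `Tensor.register_hook` and adds an extra term to the gradient that autograd computes. The tensor `w` and the constant `extra_grad` are illustrative assumptions, not part of the segmentation model; in practice the extra term would have to come from the leaf classifier.

```python
import torch

# Toy tensor standing in for model parameters (assumption for illustration).
w = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Hypothetical extra gradient term; a real one would come from the leaf classifier.
extra_grad = torch.tensor([0.1, 0.1, 0.1])

# The hook receives the gradient computed by autograd and returns its replacement.
handle = w.register_hook(lambda grad: grad + extra_grad)

loss = (w ** 2).sum()  # d(loss)/dw = 2 * w
loss.backward()

print(w.grad)  # 2 * w + extra_grad

# Remove the hook once it is no longer needed.
handle.remove()
```

The value returned by the hook replaces the gradient, so the optimizer step sees the sum of both terms.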

The goal is to add the leaf classification term linearly to the existing cross-entropy loss. The leaf classification term
will be non-differentiable with respect to the model parameters, for the reasons listed after the loss definitions below.

We can apply the leaf classifier to the output of the semantic segmentation model by:
- Grouping points by cluster
- Downsampling each cluster to 80 points
- Executing the leaf classification model on each cluster
- Adding 1 to the loss for each point if clustered correctly
- Subtracting 1 from the loss for each point if clustered incorrectly

$$
\mathcal{L}_{leaf} = \frac{\sum_k N_k \cdot C(ds(X_k), ds(Y_k), ds(Z_k))}{N_{total}}
$$
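The steps and the formula above can be sketched in NumPy as follows. This is only an illustration under stated assumptions: `downsample`, `leaf_loss` and `dummy_score` are hypothetical names, the points are assumed to be rows of an `(N, 3)` array, and `dummy_score` is a placeholder that returns +1 or -1 in place of the real leaf classifier $C$.

```python
import numpy as np

def downsample(points, n, rng):
    """Randomly pick n rows; sample with replacement only if the cluster has fewer than n points."""
    idx = rng.choice(len(points), n, replace=len(points) < n)
    return points[idx]

def leaf_loss(points, cluster_ids, score_fn, n_samples=80, seed=0):
    """Cluster-size-weighted average of a per-cluster score, as in the formula above."""
    rng = np.random.default_rng(seed)
    total = 0.0
    for k in np.unique(cluster_ids):
        cluster = points[cluster_ids == k]           # group points by cluster
        sampled = downsample(cluster, n_samples, rng)
        total += len(cluster) * score_fn(sampled)    # N_k * C(ds(...))
    return total / len(points)                       # divide by N_total

def dummy_score(cluster_points):
    """Placeholder for the leaf classifier: +1 if the mean x is positive, else -1."""
    return 1.0 if cluster_points[:, 0].mean() > 0 else -1.0

points = np.array([[1.0, 0.0, 0.0],
                   [2.0, 0.0, 0.0],
                   [3.0, 0.0, 0.0],
                   [-1.0, 0.0, 0.0],
                   [-2.0, 0.0, 0.0]])
cluster_ids = np.array([0, 0, 0, 1, 1])

loss = leaf_loss(points, cluster_ids, dummy_score)
print(loss)  # (3 * 1 + 2 * -1) / 5
```

Both the boolean grouping and the random index selection are discrete operations, which is why no gradient flows through this computation.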

$$
\mathcal{L}_{aug} = \mathcal{L}_{ce}(W | X, Y_{act}, Y_{pred}) + \mathcal{L}_{leaf}(X, Y_{act}, Y_{pred})
$$

The two reasons the leaf classification term is non-differentiable are:

1. In the current implementation, the points of each leaf cluster must be downsampled before the leaf classifier can be applied.
2. Even without downsampling, the leaf classifier must be applied to a group of points, and the grouping operation itself is non-differentiable.

The gradient update step would be:

$$
\begin{align}
W' &= W - \eta \nabla \mathcal{L}_{aug}(W | X, Y_{act}, Y_{pred}) \\
   &= W - \eta (\nabla \mathcal{L}_{ce}(W | X, Y_{act}, Y_{pred}) + \nabla \mathcal{L}_{leaf}(W | X, Y_{act}, Y_{pred}))
\end{align}
$$

However, the term

$$
\nabla \mathcal{L}_{leaf}(W | X, Y_{act}, Y_{pred})
$$

has no explicit expression, so autograd cannot derive it and it must be supplied manually, for example through a gradient hook.

In [1]:
# Example of gather operation usage

import torch

t = torch.tensor([
  [
    [1, 2, 3], 
    [3, 4, 5], 
    [5, 6, 7]
  ], 
  [
    [4, 5, 6],
    [8, 1, 9],
    [2, 1, 2]
  ]
])

torch.gather(t, 1, torch.tensor([
  [
    [1, 1, 1],
    [2, 2, 2]
  ],
  [ 
    [0, 0, 0],
    [1, 1, 1]
  ]
]))

tensor([[[3, 4, 5],
         [5, 6, 7]],

        [[4, 5, 6],
         [8, 1, 9]]])
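
For reference, with `dim=1` on a 3-D tensor, `torch.gather` follows the indexing rule below, so each index value selects a row along dimension 1:

$$
\text{out}[i][j][k] = t[i][\,\text{index}[i][j][k]\,][k]
$$

In the cell above, the index rows `[1, 1, 1]` and `[2, 2, 2]` therefore select rows 1 and 2 of the first batch element, and `[0, 0, 0]` and `[1, 1, 1]` select rows 0 and 1 of the second.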

In [2]:
# Example of repeat operation usage

import numpy as np

test = np.array([[0, 4, 6], [1, 2, 4]])
test_expanded = np.expand_dims(test, axis=2)

np.repeat(test_expanded, 3, axis=2 )

array([[[0, 0, 0],
        [4, 4, 4],
        [6, 6, 6]],

       [[1, 1, 1],
        [2, 2, 2],
        [4, 4, 4]]])

In [3]:
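# Utility function to generate rows of random choice vectors
def multi_random_choice(samples, s_size, max_value):
  """Return a (samples, s_size) array; each row holds s_size distinct values from range(max_value)."""
  # np.zeros defaults to float64, so the sampled integers are stored as floats
  out = np.zeros((samples, s_size))

  for i in range(samples):
    # replace=False guarantees no repeated index within a row
    out[i, :] = np.random.choice(max_value, s_size, replace=False)

  return out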

print(multi_random_choice(4, 8, 20))


[[16.  3.  6. 15. 10. 12.  1.  9.]
 [15.  0. 13. 18. 17. 16.  5. 14.]
 [14. 12. 15. 11. 13.  9. 18.  5.]
 [ 1.  9. 16.  2. 19. 18. 11. 10.]]


In [4]:
# Example of gradient upsampling operation

test_grad = np.ones((2, 3, 3)) 


test = np.array([[0, 4, 6], [1, 2, 4]])
test_expanded = np.expand_dims(test, axis=2)

gather_idx = np.repeat(test_expanded, 3, axis=2 )
gather_idx = torch.tensor(gather_idx, dtype=torch.int64)

test_input = np.random.choice(6, (2, 7, 3))

grad_out = torch.zeros(test_input.shape, dtype=torch.float32)

src = torch.tensor(test_grad, dtype=torch.float32) 
print("gather_idx: ", gather_idx)
print("src: ", src)
print("result: ", grad_out.scatter_(1, gather_idx, src))

gather_idx:  tensor([[[0, 0, 0],
         [4, 4, 4],
         [6, 6, 6]],

        [[1, 1, 1],
         [2, 2, 2],
         [4, 4, 4]]])
src:  tensor([[[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]],

        [[1., 1., 1.],
         [1., 1., 1.],
         [1., 1., 1.]]])
result:  tensor([[[1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[0., 0., 0.],
         [1., 1., 1.],
         [1., 1., 1.],
         [0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.]]])


In [8]:
import sys

sys.path.append("/work/murph186/repos")
sys.path.append("/work/murph186/repos/TreePartNet/")

import numpy as np

from torch.autograd import Function
from SorghumPartNet.models.extensions import Downsample

ds = Downsample.apply

test = torch.tensor([
  [
    [1, 2, 3], 
    [3, 4, 5], 
    [5, 6, 7],
    [6, 6, 4]
  ], 
  [
    [4, 5, 6],
    [8, 1, 9],
    [2, 1, 2],
    [3, 2, 4]
  ]
], requires_grad=True, dtype=torch.float32)


o = ds(test, 2)

print("result of downsample: ", o)

result = o.sum()

print("result of sum: ", result)

result.backward()
print("gradient, d result of sum / d test: ", test.grad)

result of downsample:  tensor([[[1., 2., 3.],
         [6., 6., 4.]],

        [[3., 2., 4.],
         [8., 1., 9.]]], grad_fn=<DownsampleBackward>)
result of sum:  tensor(49., grad_fn=<SumBackward0>)
gradient, d result of sum / d test:  tensor([[[1., 1., 1.],
         [0., 0., 0.],
         [0., 0., 0.],
         [1., 1., 1.]],

        [[0., 0., 0.],
         [1., 1., 1.],
         [0., 0., 0.],
         [1., 1., 1.]]])
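
The gradient pattern above (ones at the sampled points, zeros elsewhere) is what a gather in the forward pass paired with a scatter in the backward pass would produce. The sketch below shows one way such a custom `torch.autograd.Function` could be written. It is not the `Downsample` class from `SorghumPartNet`: the name `GatherDownsample` is hypothetical, and it assumes the sampled indices are passed in explicitly rather than drawn at random inside `forward`.

```python
import torch
from torch.autograd import Function

class GatherDownsample(Function):
    """Hypothetical stand-in: keep the points selected by idx for each batch element."""

    @staticmethod
    def forward(ctx, points, idx):
        # points: (B, N, C) float tensor; idx: (B, n) int64 indices into dim 1.
        gather_idx = idx.unsqueeze(-1).expand(-1, -1, points.shape[-1])
        ctx.save_for_backward(gather_idx)
        ctx.input_shape = points.shape
        return points.gather(1, gather_idx)

    @staticmethod
    def backward(ctx, grad_output):
        (gather_idx,) = ctx.saved_tensors
        # Points that were not sampled receive zero gradient.
        grad_points = torch.zeros(ctx.input_shape, dtype=grad_output.dtype,
                                  device=grad_output.device)
        # scatter_add_ rather than scatter_ so repeated indices accumulate.
        grad_points.scatter_add_(1, gather_idx, grad_output)
        # idx is an integer tensor and gets no gradient.
        return grad_points, None

points = torch.arange(24, dtype=torch.float32).reshape(2, 4, 3).requires_grad_()
idx = torch.tensor([[0, 3], [1, 3]])

out = GatherDownsample.apply(points, idx)
out.sum().backward()

print(points.grad[:, :, 0])  # 1 at the sampled rows, 0 elsewhere
```

With the indices chosen here, the rows that receive gradient match those in the output above: rows 0 and 3 of the first batch element and rows 1 and 3 of the second.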


In [9]:
# example of using .grad property
x = torch.ones(2, 2, requires_grad=True)

y = x + 2
z = y * y * 3
out = z.mean()

print(out)

out.backward()

x.grad

tensor(27., grad_fn=<MeanBackward0>)


tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])
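
The printed values can be checked by hand. With four elements $x_i = 1$:

$$
\begin{align}
out &= \frac{1}{4} \sum_{i=1}^{4} 3 (x_i + 2)^2 = 3 \cdot 3^2 = 27 \\
\frac{\partial\, out}{\partial x_i} &= \frac{1}{4} \cdot 6 (x_i + 2) = \frac{3}{2} (x_i + 2) = 4.5
\end{align}
$$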