In [119]:
import torch
import torch.nn
import torch.optim
import numpy as np

# Fix the RNG so the tensors printed below are reproducible across runs.
torch.manual_seed(69)

# Problem sizes:
#   n -- batch size
#   p -- number of input features
#   c -- number of output features (number of organ classes)
n, p, c = 10, 6, 3

def L_y(
    y_pred: torch.Tensor, y_true: torch.Tensor, d: torch.Tensor
) -> torch.Tensor:
    """Return the mean cross-entropy loss with flagged samples masked out.

    Args:
        y_pred: Unnormalized class scores (logits) accepted by
            ``torch.nn.functional.cross_entropy``, e.g. shape (N, C).
        y_true: Integer class index for each sample.
        d: Per-sample domain flag. Samples with ``d == 1`` contribute zero
            loss (and therefore zero gradient); samples with ``d == 0``
            contribute their full cross-entropy. Integer, float, or bool
            tensors are accepted.

    Returns:
        Scalar tensor: the masked per-sample losses averaged over the WHOLE
        batch. Masked samples still count in the denominator, so the value
        shrinks as the fraction of masked samples grows.
    """
    losses = torch.nn.functional.cross_entropy(y_pred, y_true, reduction="none")
    # Cast the mask to the loss dtype first: ``1 - d`` raises for bool tensors.
    keep = 1 - d.to(losses.dtype)
    losses = losses * keep  # zero the loss of unlabeled (d == 1) instances
    return torch.mean(losses)

# Random batch, drawn in a fixed order so the seeded RNG stream is stable:
# input features (tracked for autograd), organ-class labels, domain labels.
X = torch.randn(n, p, requires_grad=True)
y_true = torch.randint(0, c, (n,))
d = torch.randint(0, 2, (n,))

# One affine layer y_pred = X W^T + b, scored with the masked loss.
linear_layer = torch.nn.Linear(p, c)
y_pred = linear_layer(X)
L = L_y(y_pred, y_true, d)

print("y_pred = X W^T + b; L(y_pred, y_true, d)")
for caption, value in (
    ("X=\n", X),
    ("y_true=", y_true),
    ("y_pred=\n", y_pred),
    ("d=", d),
    ("W=\n", linear_layer.weight.data),
    ("b=", linear_layer.bias.data),
):
    print(caption, value)

# Backpropagate to fill X.grad, then show it.
L.backward()
print("dL/dX=\n", X.grad)

# Rows whose gradient is entirely zero should coincide with the d == 1 rows,
# since their losses were multiplied by zero.
zero_gradient_rows = torch.where(X.grad.abs().sum(dim=1) == 0)
print("Row indices of X.grad with all zeros:\n\t", zero_gradient_rows)
print("Samples for which d is 1:\n\t", torch.where(d == 1))

y_pred = X W^T + b; L(y_pred, y_true, d)
X=
 tensor([[-0.5300, -1.3035,  0.4438,  1.2221,  1.0395,  0.9608],
        [ 0.4214,  0.7452, -1.8389, -1.2497, -0.2485,  0.1428],
        [-1.0509,  0.3527, -0.0916,  0.0341, -0.8986,  0.1022],
        [-0.6627, -0.1350, -0.3983, -1.7892,  1.2785,  1.3351],
        [-0.3066,  1.0382,  1.2762,  0.0419, -1.2794, -1.8432],
        [ 0.8633, -1.7786, -0.8080, -0.8735,  0.9367, -1.2319],
        [ 1.5287, -0.2759, -0.8625, -0.1915, -0.4807, -1.4154],
        [ 0.0934, -0.2420,  1.0251,  1.3822,  2.1080,  0.2562],
        [ 0.4913,  1.0152, -0.0184, -0.8487, -1.6520, -1.1392],
        [ 0.6818,  0.4731, -0.4292, -1.0216, -0.5285,  1.6272]],
       requires_grad=True)
y_true= tensor([0, 1, 2, 0, 1, 2, 2, 1, 2, 1])
y_pred=
 tensor([[-0.3210, -0.3229,  0.7304],
        [-0.6966, -0.3994, -1.3843],
        [-0.4325, -1.0315, -0.0327],
        [-1.0576, -0.1506, -0.4659],
        [ 0.3075, -0.5426,  0.5685],
        [-0.2938,  1.2050, -0.1878],
        [