In [1]:
import torch
import torch.nn.functional as F

In [2]:
# Consider 8 training examples
# Toy inputs: 8 rows of 99 uniform-random features. No seed is set, so the
# features (and every loss value printed below) differ from run to run.
X_train = torch.rand(8, 99)
# Integer class label of each example; the values used here lie in 0..4.
y_train = torch.tensor([0, 1, 2, 2, 2, 3, 4, 4])

# Number of ordinal classes. The networks below emit NUM_CLASSES-1 logits,
# one per binary "is the label greater than k?" task.
NUM_CLASSES = 5

# Niu

In [3]:
def label_to_levels(label, num_classes, dtype=torch.float32):
    """Encode an integer class label as an extended-binary "levels" vector.

    The result has ``num_classes - 1`` entries: the first ``label`` entries
    are 1 and the remaining ones are 0, e.g. label 2 with 5 classes becomes
    ``[1, 1, 0, 0]``.

    Parameters
    ----------
    label : int or 0-dim integer torch.Tensor
        Class label in ``[0, num_classes - 1]``.
    num_classes : int
        Total number of classes.
    dtype : torch.dtype, default torch.float32
        Data type of the returned tensor.

    Returns
    -------
    torch.Tensor, shape (num_classes - 1,)
        The extended-binary encoding of ``label``.

    Raises
    ------
    ValueError
        If ``label`` is outside ``[0, num_classes - 1]``. Without this check
        a negative list multiplication silently yields a vector of the wrong
        length.
    """
    label = int(label)
    if not 0 <= label <= num_classes - 1:
        raise ValueError(
            f'Class label must be in [0, {num_classes - 1}], got {label}.'
        )

    levels = [1]*label + [0]*(num_classes - 1 - label)
    return torch.tensor(levels, dtype=dtype)

# Encode every training label as its extended-binary row and stack the rows
# into a single target matrix with one row per training example.
levels = torch.stack(
    [label_to_levels(label, num_classes=NUM_CLASSES) for label in y_train]
)
levels

tensor([[0., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [4]:
# for simplicity, this is the ordinal network:
# A single linear layer stands in for the whole model; the two commented
# `def` lines mark which statement would live in which nn.Module method.

# def __init__()
# 99 input features -> one logit per binary task (NUM_CLASSES-1 of them).
niu = torch.nn.Linear(99, NUM_CLASSES-1)

# def forward(self, X_train):
logits = niu(X_train)
# Expect (8, NUM_CLASSES-1): one row per example, one column per task.
logits.shape

torch.Size([8, 4])

In [5]:
def loss_niu(logits, levels):
    """Niu et al. ordinal loss: summed binary cross-entropy over all tasks.

    Every column of ``logits`` is treated as an independent binary classifier
    whose target is the matching column of ``levels``. The per-task
    cross-entropies are summed for each example and then averaged over the
    examples.

    Parameters
    ----------
    logits : torch.Tensor, shape (n_examples, n_tasks)
        Raw (pre-sigmoid) network outputs.
    levels : torch.Tensor, shape (n_examples, n_tasks)
        Extended-binary targets (0/1) as produced by ``label_to_levels``.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    # log(sigmoid(z)) computed in one step; taking torch.log of a saturated
    # sigmoid gives -inf, and -inf * 0 then turns the whole loss into NaN.
    log_p = F.logsigmoid(logits)
    # Identity: log(1 - sigmoid(z)) == logsigmoid(z) - z.
    log_not_p = log_p - logits

    per_example = -torch.sum(log_p*levels + log_not_p*(1 - levels), dim=1)
    return torch.mean(per_example)



# Loss of the untrained `niu` layer on the toy data; X_train is unseeded
# random, so the value changes from run to run.
loss_niu(logits, levels)

tensor(2.6742, grad_fn=<MeanBackward0>)

# Conditional

In [6]:
# Build one (mask, binary targets) pair per binary task i = 0..NUM_CLASSES-2.
sets = []
for i in range(NUM_CLASSES-1):
    # Examples taking part in task i: those with label > i-1, i.e. >= i.
    label_mask = y_train > i-1
    # Binary target for those examples only: 1 if the label exceeds i.
    label_tensor = (y_train[label_mask] > i).to(torch.int64)
    sets.append((label_mask, label_tensor))
sets

[(tensor([True, True, True, True, True, True, True, True]),
  tensor([0, 1, 1, 1, 1, 1, 1, 1])),
 (tensor([False,  True,  True,  True,  True,  True,  True,  True]),
  tensor([0, 1, 1, 1, 1, 1, 1])),
 (tensor([False, False,  True,  True,  True,  True,  True,  True]),
  tensor([0, 0, 0, 1, 1, 1])),
 (tensor([False, False, False, False, False,  True,  True,  True]),
  tensor([0, 1, 1]))]

In [7]:
# same as Niu et al.
# (the architecture is the same single linear layer; only the loss differs)

# def __init__(self):
conditional_net = torch.nn.Linear(99, NUM_CLASSES-1)

# def forward(self, X_train):
# NOTE: rebinds the global `logits` that the Niu cells above also used.
logits = conditional_net(X_train)
logits.shape

torch.Size([8, 4])

In [8]:
def loss_conditional1(logits, sets):
    """Per-task binary cross-entropy for the conditional scheme (naive form).

    For task ``k`` only the examples selected by that task's mask contribute,
    and the loss is averaged over those examples. The log is taken of an
    explicit sigmoid, so this version is the numerically naive one.

    Parameters
    ----------
    logits : torch.Tensor, shape (n_examples, n_tasks)
        Raw network outputs.
    sets : sequence of (mask, targets) pairs
        ``mask`` is a boolean tensor over the examples and ``targets`` holds
        the 0/1 labels of the selected examples, one pair per task.

    Returns
    -------
    list of torch.Tensor
        One scalar loss per non-empty task, in task order. Tasks without any
        examples are left out of the list.
    """
    per_task = []
    for k, entry in enumerate(sets):
        mask, targets = entry[0], entry[1]
        # Logit of task k, restricted to the examples that belong to it.
        z = logits[mask, k]

        if len(entry[1]) < 1:
            continue

        p = torch.sigmoid(z)
        log_likelihood = torch.log(p)*targets + torch.log(1 - p)*(1 - targets)
        per_task.append(-torch.mean(log_likelihood))

    return per_task


losses = loss_conditional1(logits, sets)

# `i` indexes the returned list. loss_conditional1 omits tasks that have no
# examples, so `i` equals the task index only when no task was skipped.
for i, loss in enumerate(losses):
    print(f'task {i} loss: {loss}')

task 0 loss: 0.9681528806686401
task 1 loss: 0.8592419624328613
task 2 loss: 0.7534388899803162
task 3 loss: 0.5798597931861877


In [9]:
# More stable  implementation combining log and sigmoid

def loss_conditional2(logits, sets):
    """Per-task binary cross-entropy for the conditional scheme (stable form).

    Same quantity as the naive version, but expressed through
    ``F.logsigmoid`` so that no log of an explicit sigmoid is taken.

    Parameters
    ----------
    logits : torch.Tensor, shape (n_examples, n_tasks)
        Raw network outputs.
    sets : sequence of (mask, targets) pairs
        ``mask`` is a boolean tensor over the examples and ``targets`` holds
        the 0/1 labels of the selected examples, one pair per task.

    Returns
    -------
    list of torch.Tensor
        One scalar loss per non-empty task, in task order. Tasks without any
        examples are left out of the list.
    """
    per_task = []
    for k, entry in enumerate(sets):
        mask, targets = entry[0], entry[1]
        z = logits[mask, k]

        if len(entry[1]) < 1:
            continue

        log_p = F.logsigmoid(z)
        # log(1 - sigmoid(z)) == logsigmoid(z) - z
        log_not_p = F.logsigmoid(z) - z
        per_task.append(-torch.mean(log_p*targets + log_not_p*(1 - targets)))

    return per_task


# Same logits and sets as for loss_conditional1, so the printed values
# should agree with the naive version up to float rounding.
losses = loss_conditional2(logits, sets)

for i, loss in enumerate(losses):
    print(f'task {i} loss: {loss}')

task 0 loss: 0.9681528806686401
task 1 loss: 0.8592420220375061
task 2 loss: 0.7534387707710266
task 3 loss: 0.579859733581543


average over training examples:

In [10]:
def loss_conditional_v2(logits, sets):
    """Conditional loss averaged over training examples instead of tasks.

    The binary cross-entropy terms of all tasks are summed and divided by
    the total number of (example, task) pairs that contributed, so every
    such pair carries the same weight.

    Parameters
    ----------
    logits : torch.Tensor, shape (n_examples, n_tasks)
        Raw network outputs.
    sets : sequence of (mask, targets) pairs
        ``mask`` is a boolean tensor over the examples and ``targets`` holds
        the 0/1 labels of the selected examples, one pair per task.

    Returns
    -------
    torch.Tensor
        Scalar loss.
    """
    total = 0.
    n_terms = 0
    for k, entry in enumerate(sets):
        targets = entry[1]

        # Tasks without examples contribute nothing.
        if len(targets) < 1:
            continue

        n_terms += len(targets)
        z = logits[entry[0], k]

        log_p = F.logsigmoid(z)
        # log(1 - sigmoid(z)) == logsigmoid(z) - z
        log_not_p = F.logsigmoid(z) - z
        total += -torch.sum(log_p*targets + log_not_p*(1 - targets))
    return total/n_terms


# Unlike the two functions above, this returns one scalar tensor, not a list;
# the name `losses` is reused for it.
losses = loss_conditional_v2(logits, sets)

losses

tensor(0.8342, grad_fn=<DivBackward0>)

# Conditional with branches

In [11]:
# Before:
# A single linear layer produces all NUM_CLASSES-1 task logits at once.

conditional_net = torch.nn.Linear(99, NUM_CLASSES-1)
logits = conditional_net(X_train)
logits.shape

torch.Size([8, 4])

In [12]:
# def __init__():
# One hidden layer (99 -> 10) and one scalar head (10 -> 1) per binary task.
# NOTE: inside a real nn.Module these plain Python lists would not register
# the layers' parameters; torch.nn.ModuleList is the container for that.
branches = [torch.nn.Linear(99, 10) for i in range(NUM_CLASSES-1)]
output_layers = [torch.nn.Linear(10, 1) for i in range(NUM_CLASSES-1)]

# def forward(self, X_train):
# Preallocate the (n_examples, n_tasks) logit matrix and fill it column by
# column; torch.empty leaves the values uninitialised until they are written.
logits = torch.empty(X_train.shape[0], len(output_layers))

for i in range(len(output_layers)):
    # Branch i: linear -> ReLU -> linear, then drop the size-1 output dim so
    # the result fits into column i.
    logits[:, i] = output_layers[i](torch.relu(branches[i](X_train))).squeeze()

logits.shape

torch.Size([8, 4])

In [13]:
# Naive per-task loss on the logits of the branched network defined above.
losses = loss_conditional1(logits, sets)

for i, loss in enumerate(losses):
    print(f'task {i} loss: {loss}')

task 0 loss: 0.7900136113166809
task 1 loss: 0.7687548398971558
task 2 loss: 0.6789558529853821
task 3 loss: 0.6958199143409729


In [14]:
# Stable per-task loss on the same branched-network logits; the values should
# match the naive version up to float rounding.
losses = loss_conditional2(logits, sets)

for i, loss in enumerate(losses):
    print(f'task {i} loss: {loss}')

task 0 loss: 0.7900136709213257
task 1 loss: 0.7687548398971558
task 2 loss: 0.6789558529853821
task 3 loss: 0.6958198547363281


In [15]:
# def __init__():
# One hidden layer (99 -> 10) and one scalar head (10 -> 1) per binary task.
# NOTE: inside a real nn.Module these plain Python lists would not register
# the layers' parameters; torch.nn.ModuleList is the container for that.
branches = [torch.nn.Linear(99, 10) for i in range(NUM_CLASSES-1)]
output_layers = [torch.nn.Linear(10, 1) for i in range(NUM_CLASSES-1)]

# def forward(self, X_train):
# Alternative to preallocating: collect one logit column per task in a list.
logits = []

for i in range(len(output_layers)):
    # squeeze(1) removes only the size-1 output dimension. A bare squeeze()
    # would also remove the batch dimension when the batch has a single row.
    logits.append(output_layers[i](torch.relu(branches[i](X_train))).squeeze(1))

# Stack the columns along dim=1 to get (n_examples, n_tasks) directly; this
# stays two-dimensional for any batch size, unlike stacking then transposing.
torch.stack(logits, dim=1).shape

torch.Size([8, 4])

# Usage in Forward Pass

In [16]:
# same as Niu et al.
# (single linear layer again; a fresh, untrained one for the prediction demo)

# def __init__(self):
conditional_net = torch.nn.Linear(99, NUM_CLASSES-1)

# def forward(self, X_train):
# These logits are reused by all prediction variants in the cells below.
logits = conditional_net(X_train)
logits

tensor([[ 0.1821, -0.3466,  0.0601,  0.3062],
        [ 0.4080, -0.1538,  0.2523,  0.6325],
        [ 0.3554, -0.1273,  0.2426,  0.6384],
        [ 0.2563, -0.3050,  0.3775,  0.5575],
        [ 0.4663, -0.3350, -0.0339,  0.2436],
        [ 0.5024, -0.1526, -0.1593,  0.2748],
        [ 0.3983, -0.2236, -0.0391,  0.1512],
        [ 0.2917, -0.4038, -0.0181,  0.2092]], grad_fn=<AddmmBackward>)

In [17]:
# One row per training example, one column per binary task.
logits.shape

torch.Size([8, 4])

### Conditional v1

with `cumprod`

In [18]:
# Per-task sigmoid outputs, each in (0, 1).
probas = torch.sigmoid(logits)
# Running product along the task axis: column k becomes the product of the
# first k+1 task probabilities, so every row is non-increasing.
probas = torch.cumprod(probas, dim=1)

probas

tensor([[0.5454, 0.2259, 0.1163, 0.0670],
        [0.6006, 0.2773, 0.1560, 0.1019],
        [0.5879, 0.2753, 0.1542, 0.1009],
        [0.5637, 0.2392, 0.1419, 0.0902],
        [0.6145, 0.2563, 0.1260, 0.0706],
        [0.6230, 0.2878, 0.1325, 0.0753],
        [0.5983, 0.2658, 0.1303, 0.0701],
        [0.5724, 0.2292, 0.1136, 0.0627]], grad_fn=<CumprodBackward>)

In [19]:
def proba_to_label(probas):
    """
    Turn extended-binary task probabilities into integer class labels.

    The predicted label of an example is the number of its task
    probabilities that are strictly greater than 0.5.

    Parameters
    ----------
    probas : torch.tensor, shape(n_examples, n_labels)
        Probabilities of the binary tasks, one row per example.

    Returns
    ----------
    torch.tensor, shape(n_examples,)
        Integer label of every example.

    Examples
    ----------
    >>> # 3 training examples, 6 classes
    >>> probas = torch.tensor([[0.934, 0.861, 0.323, 0.492, 0.295],
    ...                        [0.496, 0.485, 0.267, 0.124, 0.058],
    ...                        [0.985, 0.967, 0.920, 0.819, 0.506]])
    >>> proba_to_label(probas)
    tensor([2, 0, 5])
    """
    return (probas > 0.5).sum(dim=1)

# `probas` holds the cumulative products here, so each row is non-increasing
# and the count of entries above 0.5 is the index of the first entry that is
# not above 0.5.
predicted_labels = proba_to_label(probas).float()
predicted_labels

tensor([1., 1., 1., 1., 1., 1., 1., 1.])

### Conditional v2

like v1 but without `cumprod`

In [20]:
# Raw per-task sigmoid outputs, without the cumulative product.
probas = torch.sigmoid(logits)
probas

tensor([[0.5454, 0.4142, 0.5150, 0.5760],
        [0.6006, 0.4616, 0.5627, 0.6531],
        [0.5879, 0.4682, 0.5603, 0.6544],
        [0.5637, 0.4243, 0.5933, 0.6359],
        [0.6145, 0.4170, 0.4915, 0.5606],
        [0.6230, 0.4619, 0.4603, 0.5683],
        [0.5983, 0.4443, 0.4902, 0.5377],
        [0.5724, 0.4004, 0.4955, 0.5521]], grad_fn=<SigmoidBackward>)

In [21]:
# Without cumprod the rows need not be monotone, so counting entries above
# 0.5 can include tasks that come after one that is already below 0.5.
predicted_labels = proba_to_label(probas).float()
predicted_labels

tensor([3., 3., 3., 3., 2., 2., 2., 2.])

### Conditional v2-argmax (Debug)

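Calculate the class probabilities $P(y=0), P(y=1), P(y=2), \dots$ from the conditional task probabilities as $1-P(y>0)$, $P(y>0)\,\bigl(1-P(y>1 \mid y>0)\bigr)$, $P(y>0)\,P(y>1 \mid y>0)\,\bigl(1-P(y>2 \mid y>1)\bigr)$, …, and then use argmax to find the predicted label.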

In [22]:
# Conditional task probabilities, the starting point for the per-class
# probabilities built in the next cell.
probas = torch.sigmoid(logits)
probas

tensor([[0.5454, 0.4142, 0.5150, 0.5760],
        [0.6006, 0.4616, 0.5627, 0.6531],
        [0.5879, 0.4682, 0.5603, 0.6544],
        [0.5637, 0.4243, 0.5933, 0.6359],
        [0.6145, 0.4170, 0.4915, 0.5606],
        [0.6230, 0.4619, 0.4603, 0.5683],
        [0.5983, 0.4443, 0.4902, 0.5377],
        [0.5724, 0.4004, 0.4955, 0.5521]], grad_fn=<SigmoidBackward>)

In [23]:
# Column of ones used to pad both factors to NUM_CLASSES columns.
ones = torch.ones((probas.shape[0], 1))#.to(device)
# comp_1[:, k]: product of the first k task probabilities (1 for k = 0).
comp_1 = torch.cat((ones, torch.cumprod(probas, dim=1)), dim=1)
# comp_2[:, k]: 1 - probability of task k (1 for the last class).
comp_2 = torch.cat((1-probas, ones), dim=1)
# Elementwise product gives one probability per class; each row sums to 1.
probas_y = torch.mul(comp_1, comp_2)
probas_y


tensor([[0.4546, 0.3195, 0.1096, 0.0493, 0.0670],
        [0.3994, 0.3234, 0.1212, 0.0541, 0.1019],
        [0.4121, 0.3127, 0.1210, 0.0533, 0.1009],
        [0.4363, 0.3245, 0.0973, 0.0517, 0.0902],
        [0.3855, 0.3583, 0.1303, 0.0553, 0.0706],
        [0.3770, 0.3352, 0.1553, 0.0572, 0.0753],
        [0.4017, 0.3324, 0.1355, 0.0602, 0.0701],
        [0.4276, 0.3432, 0.1156, 0.0509, 0.0627]], grad_fn=<MulBackward0>)

In [24]:
# Predicted label = class with the largest per-class probability.
predicted_labels = torch.argmax(probas_y, dim=1)
predicted_labels

tensor([0, 0, 0, 0, 0, 0, 0, 0])

### Conditional v3

Same prediction rule as `Conditional v2-argmax` above, but in the code the loss is averaged over training examples rather than over tasks.

### Conditional v3

same as `Conditional v2-argmax` but in code, average over training examples rather than tasks when computing the loss.