In [114]:
import torch
import torchvision
from torch import nn

class OneHeadModel(nn.Module):
    """DenseNet-121 feature extractor followed by one dropout + linear head.

    Args:
        device: Stored on ``self.device``. This class does not move itself to
            it; call ``model.to(device)`` separately if needed.
        p_dropout: Dropout probability applied before the linear layer.
        num_classes: Number of output logits (default 5, the original
            hard-coded value).

    ``forward(x)`` returns ``(class_out, enc_out)``: the head's logits cast to
    float32, and the pooled encoder embedding of shape ``(batch, channels)``.
    """

    def __init__(self, device, p_dropout, num_classes=5):
        super(OneHeadModel, self).__init__()

        self.device = device
        self.p_dropout = p_dropout

        # Load the pretrained DenseNet-121 and keep only its convolutional
        # feature extractor; the original ImageNet classifier is discarded.
        denseNet = torchvision.models.densenet121(weights='DEFAULT')
        self.encoder = denseNet.features
        # Embedding width of the backbone (1024 for DenseNet-121), read from the
        # model instead of being hard-coded.
        in_features = denseNet.classifier.in_features

        # Global average pooling collapses the spatial dims to 1x1.
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)

        # Classification head
        self.classification_head = nn.Sequential(
            nn.Dropout(p=self.p_dropout),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Return ``(class_logits, pooled_features)`` for an image batch ``x``."""
        # DenseNet's `features` ends with a BatchNorm layer (norm5), and
        # torchvision's DenseNet.forward applies ReLU before pooling. Do the
        # same so the head sees the activations the backbone was trained to emit.
        features = torch.relu(self.encoder(x))

        # (batch, C, H, W) -> (batch, C, 1, 1) -> (batch, C)
        enc_out = torch.flatten(self.global_avg_pool(features), 1)

        # Classification branch
        class_out = self.classification_head(enc_out).float()

        return class_out, enc_out


In [115]:
# Pass an actual device instance; passing the `torch.device` class itself
# stored a type object on `model.device` rather than a usable device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = OneHeadModel(device=device, p_dropout=0.4)

In [116]:
from torchinfo import summary

In [117]:
# Layer-by-layer summary of `model` for a batch of 32 RGB 224x224 images.
# torchinfo builds a dummy input from `input_size` and runs a forward pass
# (hence the Output Shape column); seed first so that run is reproducible.
torch.manual_seed(33)
summary(
    model=model,
    input_size=(32, 3, 224, 224),  # keyword is `input_size`, not `input_shape`
    col_names=["input_size", "output_size", "num_params", "trainable"],
    col_width=20,
    row_settings=["var_names"],
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
OneHeadModel (OneHeadModel)                   [32, 3, 224, 224]    [32, 5]              --                   True
├─Sequential (encoder)                        [32, 3, 224, 224]    [32, 1024, 7, 7]     --                   True
│    └─Conv2d (conv0)                         [32, 3, 224, 224]    [32, 64, 112, 112]   9,408                True
│    └─BatchNorm2d (norm0)                    [32, 64, 112, 112]   [32, 64, 112, 112]   128                  True
│    └─ReLU (relu0)                           [32, 64, 112, 112]   [32, 64, 112, 112]   --                   --
│    └─MaxPool2d (pool0)                      [32, 64, 112, 112]   [32, 64, 56, 56]     --                   --
│    └─_DenseBlock (denseblock1)              [32, 64, 56, 56]     [32, 256, 56, 56]    --                   True
│    │    └─_DenseLayer (denselayer1)         [32, 64, 56, 56]     [32, 32, 56, 56]    

In [84]:
def ordinal_labels(y, num_classes):
    """Convert integer class labels to cumulative ("ordinal") binary targets.

    ``out[k, i] = 1.0`` if ``y[k] >= i`` else ``0.0``. A label ``c`` therefore
    yields ``c + 1`` leading ones, and column 0 is 1 for every non-negative
    label.

    Args:
        y: 1-D tensor (or sequence) of integer class labels.
        num_classes: Number of output columns.

    Returns:
        Float32 tensor of shape ``(len(y), num_classes)`` on ``y``'s device.
    """
    y = torch.as_tensor(y)
    # Broadcast (N, 1) labels against (num_classes,) thresholds -> (N, num_classes),
    # allocating on y's device instead of always on CPU.
    thresholds = torch.arange(num_classes, device=y.device)
    return (y.unsqueeze(-1) >= thresholds).float()

In [107]:
# Example integer class labels (0-4) for six samples.
y = torch.tensor([2, 3, 1, 0, 1, 4])

# Example raw logits from an ordinal head: two samples, one logit per threshold.
ord_out = torch.tensor(
    [
        [3, 5, 1, -7, -6],
        [8, -4, -3, -10, -3],
    ]
)

In [108]:
# Cumulative ordinal targets for the example labels, using 5 classes.
ordinal_labels(y, num_classes=5)

tensor([[1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 0.],
        [1., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1.]])

In [109]:
# Per-threshold probabilities for the example logits.
ord_out.sigmoid()

tensor([[9.5257e-01, 9.9331e-01, 7.3106e-01, 9.1105e-04, 2.4726e-03],
        [9.9966e-01, 1.7986e-02, 4.7426e-02, 4.5398e-05, 4.7426e-02]])

In [112]:
# Decode ordinal logits back to class indices. Each column's sigmoid is rounded
# to 0/1; under the cumulative encoding of `ordinal_labels`, class c has c + 1
# ones, so the prediction is (number of ones) - 1. Clamp at 0 so a row where no
# threshold fires maps to class 0 rather than the invalid index -1.
ord_bits = torch.round(torch.sigmoid(ord_out))
y_pred_ord = (ord_bits.sum(dim=1) - 1).clamp(min=0)


In [113]:
# Decoded class index for each row of `ord_out` (float dtype, as computed in the previous cell).
y_pred_ord

tensor([2., 0.])

In [75]:
import numpy as np

y_train = np.array([[0, 0, 1, 0, 0], [0, 1, 0, 0, 0]])

# Convert one-hot rows into cumulative ("multilabel") ordinal targets: column j
# is 1 when the true class index is >= j. This is a reverse cumulative OR along
# each row (flip columns, accumulate with logical OR, flip back). It works for
# any number of classes, with no hard-coded column indices.
reverse_cum_or = np.logical_or.accumulate(y_train[:, ::-1].astype(bool), axis=1)[:, ::-1]
y_train_multi = reverse_cum_or.astype(y_train.dtype)

print("Original y_train:", y_train.sum(axis=0))
print("Multilabel version:", y_train_multi.sum(axis=0))

Original y_train: [0 1 1 0 0]
Multilabel version: [2 2 1 0 0]


In [76]:
# Show the one-hot labels next to their cumulative (multilabel) form.
for caption, labels in (("Original y_train:", y_train), ("Multilabel version:", y_train_multi)):
    print(caption, labels)

Original y_train: [[0 0 1 0 0]
 [0 1 0 0 0]]
Multilabel version: [[1 1 1 0 0]
 [1 1 0 0 0]]
