In [5]:
import torch
from torch import nn
import torch.nn.functional as F

from torchinfo import summary

In [99]:
class DropConnectLinear(nn.Linear):
    '''
    Linear layer whose weights are randomly dropped (DropConnect) during training.
    This is based on code from torchnlp: 
        https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/weight_drop.html
    Note: The torchnlp implementation may not follow the original drop connect paper. 
          For one, it uses inverted dropout (surviving weights are rescaled during training,
          so nothing has to be rescaled at evaluation time).
    Original Paper: https://proceedings.mlr.press/v28/wan13.html
    
    Args:
        in_features (int): Number of features in the input.
        out_features (int): Number of features in the output.
        drop_prob (float): Probability of a weight being dropped during training.
        **kwargs: Forwarded unchanged to nn.Linear (e.g. bias, device, dtype).

    Raises:
        ValueError: If drop_prob is not within [0, 1].
    '''
    def __init__(self, in_features: int, out_features: int, drop_prob: float = 0.5, **kwargs):
        # Fail fast: without this check an invalid probability is only rejected by
        # F.dropout on the first forward pass, far away from the faulty constructor call.
        if not 0.0 <= drop_prob <= 1.0:
            raise ValueError(f'drop_prob must be in [0, 1], but got {drop_prob}')

        super().__init__(in_features, out_features, **kwargs)
        self.drop_prob = drop_prob

    def forward(self, X):
        '''
        Applies the linear map using a freshly masked copy of the weights.

        The mask is only sampled when the module is in training mode; in eval mode
        the full weight matrix is used. The bias is never dropped.
        '''
        # Drop random weights during training (self.weight itself is left untouched)
        drop_weights = F.dropout(self.weight, p = self.drop_prob, training = self.training)
        
        return F.linear(X, drop_weights, self.bias)

    def extra_repr(self) -> str:
        # Surface the drop probability in print(model) next to the nn.Linear fields
        return f'{super().extra_repr()}, drop_prob={self.drop_prob}'

class ConvBNDrop(nn.Module):
    '''
    Convolution block made of a lazy 2D convolution, ReLU, batch norm and dropout.

    Default layer order:            Conv -> ReLU -> BatchNorm -> Dropout
    Layer order with pre_dropout:   Dropout -> Conv -> ReLU -> BatchNorm

    Args:
        out_channels (int): Number of channels produced by the convolution.
        kernel_size (int): Size of the convolution kernel.
        drop_prob (float): Probability passed to the nn.Dropout layer.
        pre_dropout (bool): If True, dropout is applied to the block input
            instead of the block output.
    '''
    def __init__(self, 
                 out_channels: int, 
                 kernel_size: int = 3 , 
                 drop_prob: float = 0.35, 
                 pre_dropout: bool = False):
        super().__init__()

        # nn.LazyConv2d infers the number of input channels on the first forward pass
        conv_stack = [nn.LazyConv2d(out_channels, kernel_size),
                      nn.ReLU(),
                      nn.BatchNorm2d(out_channels)]
        dropout = nn.Dropout(drop_prob)

        # Dropout goes either in front of or behind the conv stack
        if pre_dropout:
            ordered_layers = [dropout, *conv_stack]
        else:
            ordered_layers = [*conv_stack, dropout]

        self.conv_bn_drop = nn.Sequential(*ordered_layers)

    def forward(self, X):
        '''Runs X through the block's layers in their configured order.'''
        out = self.conv_bn_drop(X)
        return out

class EnsNetBaseCNN(nn.Module):
    '''
    Base CNN of EnsNet: a two-block convolutional body followed by a
    fully connected classifier with 10 outputs.

    The forward pass returns both the classifier logits and the final feature
    maps of the convolutional body, so the feature maps can be reused elsewhere.
    '''
    def __init__(self):
        super().__init__()

        # Block 1: two conv blocks with trailing dropout, then conv -> ReLU -> BN -> max pool
        first_block = nn.Sequential(
            ConvBNDrop(64, kernel_size = 3, drop_prob = 0.35),
            ConvBNDrop(128, kernel_size = 3, drop_prob = 0.35),
            nn.LazyConv2d(256, kernel_size = 3),
            nn.ReLU(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(kernel_size = 2),
        )

        # Block 2: three conv blocks with leading dropout, then max pool -> dropout
        second_block = nn.Sequential(
            *[ConvBNDrop(width, kernel_size = 3, drop_prob = 0.35, pre_dropout = True)
              for width in (512, 1024, 2000)],
            nn.MaxPool2d(kernel_size = 2),
            nn.Dropout(0.35),
        )

        # Register the blocks under explicit names (these names end up in the state dict keys)
        self.cnn_body = nn.Sequential()
        for block_name, block in (('cnn_block_1', first_block),
                                  ('cnn_block_2', second_block)):
            self.cnn_body.add_module(block_name, block)

        # Classifier head on top of the flattened feature maps
        classifier_layers = [
            nn.Flatten(),
            nn.LazyLinear(512), nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            DropConnectLinear(512, 512, drop_prob = 0.5), nn.ReLU(),
            nn.LazyLinear(10),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, X):
        '''
        Returns:
            tuple: (logits, final_feature_maps), where final_feature_maps is the
            output of the convolutional body and logits is the classifier output
            computed from those feature maps.
        '''
        feature_maps = self.cnn_body(X)
        return self.classifier(feature_maps), feature_maps

class EnsNetFCSN(nn.Module):
    '''
    Fully connected sub-network of EnsNet.

    Layer order: LazyLinear(512) -> ReLU -> BatchNorm1d -> Dropout(0.5)
                 -> DropConnectLinear(512, 512) -> ReLU -> LazyLinear(10)

    The first layer is lazy, so the number of input features is inferred
    on the first forward pass.
    '''
    def __init__(self):
        super().__init__()

        head_layers = [
            nn.LazyLinear(512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.5),
            DropConnectLinear(512, 512, drop_prob = 0.5),
            nn.ReLU(),
            nn.LazyLinear(10),
        ]
        self.subnet = nn.Sequential(*head_layers)

    def forward(self, X):
        '''Returns the sub-network logits for the (already flattened) input X.'''
        logits = self.subnet(X)
        return logits
        
class EnsNet(nn.Module):
    '''
    Ensemble of a base CNN and several fully connected sub-networks.

    The base CNN produces logits and feature maps. The feature maps are divided
    along the channel dimension into num_subnets chunks, and each chunk is
    classified by its own sub-network. predict() combines the base CNN and all
    sub-networks by majority vote.

    Args:
        num_subnets (int): Number of sub-networks (and feature map chunks).

    Raises:
        ValueError: If num_subnets is smaller than 1.
    '''
    def __init__(self, num_subnets):
        super().__init__()

        # Fail at construction instead of inside torch.chunk during the first forward pass
        if num_subnets < 1:
            raise ValueError(f'num_subnets must be at least 1, but got {num_subnets}')

        self.num_subnets = num_subnets
        
        # Create base CNN to extract features before division
        self.base_cnn = EnsNetBaseCNN()
        
        self.subnets = nn.ModuleList([
            EnsNetFCSN() for _ in range(num_subnets)
        ])
    
    def predict(self, X):
        '''
        Predicts one class per sample by majority vote over the base CNN and all sub-networks.

        Note: This method neither switches the model to eval mode nor disables
              gradient tracking; the caller is responsible for both.

        Returns:
            Tensor with one predicted class index per sample. Ties are resolved by torch.mode.
        '''
        # Get logits from the base CNN and subnets
        # Call the module itself (not .forward) so that registered hooks are executed
        cnn_logits, subnet_logits, _ = self(X)
        all_logits = [cnn_logits] + subnet_logits # num_voters = num_subnets + 1

        # Get predicted classes from the base CNN and each subnet
        pred_classes = torch.stack(all_logits, dim = 0).argmax(dim = -1) # Shape: (num_voters, batch_size)
 
        # Get majority vote among predicted classes
        return torch.mode(pred_classes, dim = 0).values # Mode along voter dimension

    def forward(self, X):
        '''
        Returns:
            tuple: (cnn_logits, subnet_logits, cnn_feat_maps), where subnet_logits is a
            list holding one logits tensor per sub-network.

        Raises:
            ValueError: If the feature maps cannot be divided into one chunk per sub-network.
        '''
        # Get logits and feature maps from base CNN
        cnn_logits, cnn_feat_maps = self.base_cnn(X)

        # Divide the feature maps into distinct chunks along the channel dimension
            # If cnn_feat_maps.shape[1] isn't divisible by num_subnets, the last chunk will have less channels
        div_feat_maps = torch.chunk(cnn_feat_maps, chunks = self.num_subnets, dim = 1)

        # torch.chunk may return fewer chunks than requested. zip() below would then
        # silently skip the trailing sub-networks, so report the mismatch explicitly.
        if len(div_feat_maps) != len(self.subnets):
            raise ValueError(
                f'Could not divide {cnn_feat_maps.shape[1]} feature map channels into '
                f'{len(self.subnets)} chunks (got {len(div_feat_maps)} chunks)'
            )

        # Get logits from subnetworks
        subnet_logits = []
        for feat_map, subnet in zip(div_feat_maps, self.subnets):
            feat_map = torch.flatten(feat_map, start_dim = 1) # Flatten feature map
            subnet_logits.append(subnet(feat_map))

        return cnn_logits, subnet_logits, cnn_feat_maps

In [100]:
# Build a 10-subnet ensemble, then a dummy batch shaped like 64 MNIST images
ensnet = EnsNet(num_subnets = 10)
X_dummy = torch.rand(64, 1, 28, 28) # (batch_size, channels, height, width)

# Optional check of the majority-vote path (left disabled):
# ensnet.eval()
# with torch.inference_mode():
#     pred_labels = ensnet.predict(X_dummy)

# Display settings for the layer-by-layer summary table
summary_options = dict(
    col_names = ['input_size', 'output_size', 'num_params'],
    col_width = 20,
    row_settings = ['var_names'],
)

# Last expression of the cell, so the summary table is rendered as the cell output
summary(model = ensnet, input_size = X_dummy.shape, **summary_options)

Layer (type (var_name))                                 Input Shape          Output Shape         Param #
EnsNet (EnsNet)                                         [64, 1, 28, 28]      [64, 10]             --
├─EnsNetBaseCNN (base_cnn)                              [64, 1, 28, 28]      [64, 10]             --
│    └─Sequential (cnn_body)                            [64, 1, 28, 28]      [64, 2000, 2, 2]     --
│    │    └─Sequential (cnn_block_1)                    [64, 1, 28, 28]      [64, 256, 11, 11]    370,560
│    │    └─Sequential (cnn_block_2)                    [64, 256, 11, 11]    [64, 2000, 2, 2]     24,340,848
│    └─Sequential (classifier)                          [64, 2000, 2, 2]     [64, 10]             --
│    │    └─Flatten (0)                                 [64, 2000, 2, 2]     [64, 8000]           --
│    │    └─Linear (1)                                  [64, 8000]           [64, 512]            4,096,512
│    │    └─ReLU (2)                                    [64, 512] 