In [12]:
import torch.nn as nn
import torch

In [21]:
class CNN(nn.Module):
    """Binary classifier whose convolutional depth is switched on/off by ``gates``.

    ``conv_1`` always runs. ``conv_2`` .. ``conv_6`` each run only when the
    matching entry ``gates[0]`` .. ``gates[4]`` is truthy. Every executed conv
    is followed by max-pool(2) -> batch-norm (no affine) -> ReLU -> dropout(0.2).
    The flattened feature map then feeds fc(512) -> fc(256) -> fc(1) -> sigmoid.

    Every normalisation, pooling and dropout layer is a registered submodule, so
    ``.train()`` / ``.eval()`` and ``.to(device)`` apply to it, and batch-norm
    running statistics are learned. Note that this adds ``bn_*`` buffers to
    ``state_dict()``.

    Args:
        gates: sequence of at least 5 truthy/falsy flags enabling conv_2..conv_6.
            A copy is stored, so mutating the caller's list later has no effect
            on this model.
        in_feature: size of the flattened conv output. It depends on the input
            resolution and on which gates are enabled (e.g. 350464 for a
            3x150x150 input with all gates off).

    Raises:
        ValueError: if ``gates`` has fewer than 5 entries.
    """

    def __init__(self, gates, in_feature):
        super(CNN, self).__init__()
        # Copy so later mutation of the caller's list cannot change this model's depth.
        self.gates = list(gates)
        if len(self.gates) < 5:
            raise ValueError(f"gates must have at least 5 entries, got {len(self.gates)}")

        self.conv_1 = nn.Conv2d(3, 64, 3)
        self.conv_2 = nn.Conv2d(64, 64, 3)
        self.conv_3 = nn.Conv2d(64, 128, 3)
        self.conv_4 = nn.Conv2d(128, 128, 3)
        self.conv_5 = nn.Conv2d(128, 256, 3)
        self.conv_6 = nn.Conv2d(256, 256, 1)

        # One norm layer per conv stage: stages with equal channel counts must
        # not share running statistics.
        self.bn_1 = nn.BatchNorm2d(64, affine=False)
        self.bn_2 = nn.BatchNorm2d(64, affine=False)
        self.bn_3 = nn.BatchNorm2d(128, affine=False)
        self.bn_4 = nn.BatchNorm2d(128, affine=False)
        self.bn_5 = nn.BatchNorm2d(256, affine=False)
        self.bn_6 = nn.BatchNorm2d(256, affine=False)
        self.pool = nn.MaxPool2d(2)
        self.dropout = nn.Dropout(0.2)

        self.relu = nn.ReLU()
        self.flatten = nn.Flatten()
        self.fc_1 = nn.Linear(in_feature, 512)
        self.bn_fc_1 = nn.BatchNorm1d(512, affine=False)
        self.fc_2 = nn.Linear(512, 256)
        self.bn_fc_2 = nn.BatchNorm1d(256, affine=False)
        self.fc_3 = nn.Linear(256, 1)
        self.sigmoid = nn.Sigmoid()

    def down_sample(self, x, norm=None):
        """Apply max-pool(2) -> batch-norm -> ReLU -> dropout to a conv output.

        Args:
            x: 4-D feature map (N, C, H, W).
            norm: registered ``BatchNorm2d`` for this stage. If omitted, batch
                statistics are used directly. This keeps the old one-argument
                call working, but nothing is learned and eval mode is ignored
                for the normalisation.
        """
        x = self.pool(x)
        if norm is None:
            x = nn.functional.batch_norm(x, None, None, training=True)
        else:
            x = norm(x)
        x = self.relu(x)
        return self.dropout(x)

    def batch_norm(self, x, norm=None):
        """Apply batch-norm followed by ReLU to a 2-D (N, features) tensor.

        ``norm`` is the registered ``BatchNorm1d`` to use. If omitted, batch
        statistics are used (backward-compatible one-argument call).
        """
        if norm is None:
            x = nn.functional.batch_norm(x, None, None, training=True)
        else:
            x = norm(x)
        return self.relu(x)

    def conv(self, x):
        """Run conv_1 plus every gated conv stage, then flatten to (N, features)."""
        x = self.down_sample(self.conv_1(x), self.bn_1)
        # Stage k (conv_{k+2}) is controlled by gates[k].
        gated_stages = (
            (self.conv_2, self.bn_2),
            (self.conv_3, self.bn_3),
            (self.conv_4, self.bn_4),
            (self.conv_5, self.bn_5),
            (self.conv_6, self.bn_6),
        )
        for enabled, (conv_layer, norm) in zip(self.gates, gated_stages):
            if enabled:
                x = self.down_sample(conv_layer(x), norm)
        return self.flatten(x)

    def classifier(self, x):
        """Map flattened features to a probability in [0, 1] with shape (N, 1)."""
        x = self.relu(self.fc_1(x))
        x = self.batch_norm(x, self.bn_fc_1)
        x = self.relu(self.fc_2(x))
        x = self.batch_norm(x, self.bn_fc_2)
        return self.sigmoid(self.fc_3(x))

    def forward(self, x):
        """Full pass: gated conv feature extractor followed by the MLP head."""
        return self.classifier(self.conv(x))

# Multi CNN: forward-pass shape check for every number of enabled conv stages (0–5)

In [22]:
# Sanity check: build one CNN per depth (0..5 gated conv stages enabled) and
# confirm each maps a batch of 32 RGB 150x150 images to 32 probabilities.
torch.manual_seed(0)  # reproducible dummy inputs, weights and dropout masks

# Flattened conv-output size for a 3x150x150 input, indexed by how many of the
# five gates are enabled (gates are switched on in order gates[0], gates[1], ...).
in_features = [350464, 82944, 36992, 6272, 1024, 256]
# randint's upper bound is exclusive: 256 gives the full 0..255 pixel range.
inputs = torch.randint(0, 256, (32, 3, 150, 150), dtype=torch.float32)

for i in range(len(in_features)):
    # Fresh list per model: the first i gates enabled, the rest disabled.
    gates = [k < i for k in range(5)]
    cnn = CNN(gates, in_features[i])
    with torch.no_grad():  # shape check only; no gradients needed
        out = cnn(inputs)
    assert out.shape == (32, 1), f"depth {i}: unexpected output shape {tuple(out.shape)}"
    print(i)
    print(out.shape)


0
torch.Size([32, 1])
1
torch.Size([32, 1])
2
torch.Size([32, 1])
3
torch.Size([32, 1])
4
torch.Size([32, 1])
5
torch.Size([32, 1])
