In [1]:
import torch

In [2]:
# Show the installed PyTorch version so the environment used for the
# measurements below is recorded alongside the results.
torch.__version__

'2.0.1'

In [3]:
from modules.models.models import SimpleCNN, init_weights
import torchsummary

def count_parameters(model):
    """Return the number of trainable scalar values in ``model``.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; parameters with the flag
    cleared are ignored. A model without trainable parameters yields 0.
    """
    trainable_total = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

In [4]:
# Instantiate the network under inspection: dna2vec-encoded input, no
# reverse-strand concatenation, Xavier-normal weight initialisation.
model = SimpleCNN(
    encode_method="dna2vec",
    concat_reverse=False,
    init_method="xavier_normal",
)

In [5]:
# Layer-by-layer summary on CPU for a (2000, 200) input with batch size 64,
# followed by the total number of trainable parameters.
torchsummary.summary(model, (2000, 200), batch_size=64, device="cpu")
trainable_count = count_parameters(model)
print('parameters_count:', trainable_count)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1            [64, 300, 2000]       1,200,300
              ReLU-2            [64, 300, 2000]               0
            Conv1d-3            [64, 300, 2000]       1,200,300
              ReLU-4            [64, 300, 2000]               0
         MaxPool1d-5             [64, 300, 100]               0
              ReLU-6             [64, 300, 100]               0
         MaxPool1d-7             [64, 300, 100]               0
              ReLU-8             [64, 300, 100]               0
 CustomConcatLayer-9             [64, 300, 200]               0
          Flatten-10                [64, 60000]               0
           Linear-11                  [64, 800]      48,000,800
             ReLU-12                  [64, 800]               0
          Dropout-13                  [64, 800]               0
           Linear-14                   

  return F.conv1d(input, weight, bias, self.stride,
  input = module(input)


In [20]:
# Build a fresh SimpleCNN with the same configuration as the earlier cell
# before running the complexity profiler on it.
model = SimpleCNN(
    encode_method="dna2vec",
    concat_reverse=False,
    init_method="xavier_normal",
)

In [23]:
from ptflops import get_model_complexity_info
import re

# Profile the model with ptflops for a (2000, 200) input. With
# as_strings=True both results come back as human-readable strings
# (e.g. "4.85 GMac" and "50.4 M"); print_per_layer_stat/verbose also emit a
# per-layer breakdown.
# NOTE(review): torch.cuda.device(0) selects CUDA device 0 and therefore
# presumably needs a CUDA-capable machine -- confirm whether it is required.
with torch.cuda.device(0):
    macs, params = get_model_complexity_info(
        model, (2000, 200), as_strings=True,
        print_per_layer_stat=True, verbose=True)

# Pull the leading number out of the MAC string and double it, i.e. count
# each multiply-accumulate as two floating-point operations. float() is used
# instead of eval(): it parses the numeric text without executing it as code.
mac_value = float(re.findall(r'([\d.]+)', macs)[0])
flops = mac_value * 2
# Keep only the first letter of the unit ("GMac" -> "G") as the magnitude
# prefix for the FLOPs figure.
flops_unit = re.findall(r'([A-Za-z]+)', macs)[0][0]

SimpleCNN(
  50.4 M, 100.000% Params, 4.85 GMac, 100.000% MACs, 
  (enhancer_conv1): Sequential(
    1.2 M, 2.381% Params, 2.4 GMac, 49.492% MACs, 
    (0): Conv1d(1.2 M, 2.381% Params, 2.4 GMac, 49.480% MACs, 100, 300, kernel_size=(40,), stride=(1,), padding=same)
    (1): ReLU(0, 0.000% Params, 600.0 KMac, 0.012% MACs, )
  )
  (enhancer_maxpool1): Sequential(
    0, 0.000% Params, 630.0 KMac, 0.013% MACs, 
    (0): MaxPool1d(0, 0.000% Params, 600.0 KMac, 0.012% MACs, kernel_size=20, stride=20, padding=0, dilation=1, ceil_mode=False)
    (1): ReLU(0, 0.000% Params, 30.0 KMac, 0.001% MACs, )
  )
  (promoter_conv1): Sequential(
    1.2 M, 2.381% Params, 2.4 GMac, 49.492% MACs, 
    (0): Conv1d(1.2 M, 2.381% Params, 2.4 GMac, 49.480% MACs, 100, 300, kernel_size=(40,), stride=(1,), padding=same)
    (1): ReLU(0, 0.000% Params, 600.0 KMac, 0.012% MACs, )
  )
  (promoter_maxpool1): Sequential(
    0, 0.000% Params, 630.0 KMac, 0.013% MACs, 
    (0): MaxPool1d(0, 0.000% Params, 600.0 KMac, 0

In [24]:
# Report the profiling results: the MAC count as returned by the profiler,
# the derived FLOPs figure with its magnitude prefix, and the parameter count.
report_rows = (
    ('Computational complexity: {:<8}', (macs,)),
    ('Computational complexity: {} {}Flops', (flops, flops_unit)),
    ('Number of parameters: {:<8}', (params,)),
)
for template, values in report_rows:
    print(template.format(*values))

Computational complexity: 4.85 GMac
Computational complexity: 9.7 GFlops
Number of parameters: 50.4 M  
