In [1]:
import torch
import torch.nn as nn 
import torch.functional as F
import torchinfo
import hydra
import math
from hydra import compose,initialize
from omegaconf import OmegaConf
from model_cnn import MRNA_CNN_CONCAT

In [25]:
class MRNA_CNN_CONCAT(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg.model
        self.layer1 = nn.Sequential(
            nn.Conv1d(
                in_channels=4,
                out_channels=self.cfg.conv1_ch,
                kernel_size=self.cfg.conv1_kernel,
                stride=self.cfg.stride1,
                dilation=self.cfg.dilation1,
            ),
            nn.ReLU(),
            nn.BatchNorm1d(
                num_features=self.cfg.conv1_ch
            ),  # num_features = channel num (Batch,Channel,Length)
            nn.Dropout(p=0.02),
            nn.MaxPool1d(kernel_size=2, stride=1),
        )
        self.layer1_out_size=math.floor((self.cfg.max_len-self.cfg.dilation1*(self.cfg.conv1_kernel-1)-1)/self.cfg.stride1)


        self.layer2 = nn.Sequential(
            nn.Conv1d(
                in_channels=self.cfg.conv1_ch,
                out_channels=self.cfg.conv2_ch,
                kernel_size=self.cfg.conv2_kernel,
                padding="same",
                stride=self.cfg.stride2,
                dilation=self.cfg.dilation2,
            ),
            nn.ReLU(),
            nn.BatchNorm1d(num_features=self.cfg.conv2_ch),
            nn.Dropout(p=0.02),
            # nn.MaxPool1d(kernel_size=2, stride=2),
        )

        self.layer3 = nn.Sequential(
            nn.Conv1d(
                in_channels=self.cfg.conv2_ch,
                out_channels=self.cfg.conv3_ch,
                kernel_size=self.cfg.conv3_kernel,
                padding="same",
                stride=self.cfg.stride3,
                dilation=1,
            ),
            nn.ReLU(),
            nn.BatchNorm1d(num_features=self.cfg.conv3_ch),
            nn.Dropout(p=0.02),
        )

        self.linear_input_dim = self.layer1_out_size*self.cfg.conv3_ch+self.cfg.feat_dim

        self.layer4 = nn.Sequential(
            nn.Linear(
                in_features=self.linear_input_dim, out_features=self.cfg.linear1_dim
            ),
            nn.ReLU(),
            nn.BatchNorm1d(num_features=self.cfg.linear1_dim),
            nn.Dropout(p=0.02),
        )

        self.layer5 = nn.Sequential(
            nn.Linear(
                in_features=self.cfg.linear1_dim, out_features=self.cfg.linear2_dim
            ),
            nn.ReLU(),
            nn.BatchNorm1d(num_features=self.cfg.linear2_dim),
            nn.Dropout(p=0.02),
        )

        self.layer_out = nn.Sequential(
            nn.Linear(in_features=self.cfg.linear2_dim, out_features=1)
        )

        self.flatter = nn.Flatten()

    def forward(self, x_seq, x_feat):
        print(f"input:{x_seq.size()}")
        x_seq = self.layer1(x_seq)  # x_seq:(bs,ch,seq_len)->(bs,c)
        print(f"layer1:{x_seq.size()}")
        x_seq = self.layer2(x_seq)  # x_seq:
        print(f"layer2:{x_seq.size()}")
        x_seq = self.layer3(x_seq)
        print(f"layer3:{x_seq.size()}")

        x_seq = self.flatter(x_seq)
        print(f"flatten:{x_seq.size()}")
        x_concat = torch.cat((x_seq, x_feat), 1)  # concat (seq,feat)
        print(f"concat:{x_concat.size()}")
        x_concat = self.layer4(x_concat)
        x_concat = self.layer5(x_concat)
        out = self.layer_out(x_concat)

        return out


In [2]:
hydra.core.global_hydra.GlobalHydra.instance().clear()
initialize(config_path="configs")
cfg=compose(config_name="cnn_concat")
model=MRNA_CNN_CONCAT(cfg)

The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  initialize(config_path="configs")


In [4]:
input1=torch.rand((32,4,500))
input2=torch.rand((32,66))
out=model(input1,input2)
print(out.size())

torch.Size([32, 1])


In [11]:
torchinfo.summary(model,input_data=(input1,input2))

RuntimeError: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: []