In [1]:
import torch 
from torch import nn
import numpy as np
from torchinfo import summary


device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cpu device


In [20]:
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.fcn = nn.Sequential(
            nn.Linear(300, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.cnn = nn.Sequential(
            nn.Conv2d(in_channels=1,
                      out_channels=4,
                      kernel_size=5,
                      stride=1,
                      padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=4,
                      out_channels=8,
                      kernel_size=3,
                      stride=1,
                      padding=0),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=8,
                      out_channels=16,
                      kernel_size=2,
                      stride=1,
                      padding=0),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten()
        )
        self.fused_fcn = nn.Sequential(
            nn.Linear(128+576, 128),
            nn.ELU(),
            nn.Linear(128, 64),
            nn.ELU()
        )

    def forward(self, x):
        n = 64
        x = self.flatten(x)
        image_input = x[:, :n*n].view(-1, 1, 64, 64)
        vector_input = x[:, n*n:]
        # print(image_input.shape)
        # print(vector_input.shape)
        # print(x.shape)
        vector_out = self.fcn(vector_input)
        image_out = self.cnn(image_input)
        # print(vector_out.shape)
        print(image_out.shape)
        combined = torch.cat((vector_out, image_out), dim=-1)
        # print(combined.shape)
        out = self.fused_fcn(combined)

        return 


In [21]:
model = NeuralNetwork().to(device)
print(summary(model, input_size=(1, 4396)))

torch.Size([1, 576])
Layer (type:depth-idx)                   Output Shape              Param #
NeuralNetwork                            --                        --
├─Flatten: 1-1                           [1, 4396]                 --
├─Sequential: 1-2                        [1, 128]                  --
│    └─Linear: 2-1                       [1, 256]                  77,056
│    └─ReLU: 2-2                         [1, 256]                  --
│    └─Linear: 2-3                       [1, 128]                  32,896
│    └─ReLU: 2-4                         [1, 128]                  --
├─Sequential: 1-3                        [1, 576]                  --
│    └─Conv2d: 2-5                       [1, 4, 60, 60]            104
│    └─MaxPool2d: 2-6                    [1, 4, 30, 30]            --
│    └─Conv2d: 2-7                       [1, 8, 28, 28]            296
│    └─MaxPool2d: 2-8                    [1, 8, 14, 14]            --
│    └─Conv2d: 2-9                       [1, 16, 13, 1

In [23]:
# 4096 x (64 x 64 + 300)
input = torch.rand(1, 4396, device=device)
# output = model(input)
# print(output.shape)
