In [None]:
import torch
from torch import nn
import torch.nn.functional as F

from lib.mnist_aug.mnist_augmenter import DataManager, MNISTAug

In [None]:
# Project helpers: `aug` generates the augmented (captioned) samples in the next
# cell; `dm` supplies the source arrays (dm.x_train / dm.y_train) it draws from.
aug = MNISTAug()
dm = DataManager()
dm.load()  # presumably fills dm.x_train / dm.y_train read in the next cell

In [None]:
# NOTE(review): presumably bounds how many digits go into each augmented image.
aug.max_objects = 5
aug.min_objects = 3

# NOTE(review): the meaning of the positional `2` is not visible from this
# notebook -- give it a named constant once confirmed in MNISTAug.get_augmented.
x_train, y_train = aug.get_augmented(dm.x_train, dm.y_train, 2, get_captions=True)
# Build the test split from held-out digits so evaluation never sees digits the
# model was trained on. Assumes DataManager.load() populates x_test / y_test
# alongside x_train / y_train -- TODO confirm.
x_test, y_test = aug.get_augmented(dm.x_test, dm.y_test, 2, get_captions=True)

DataManager.plot_num(x_train[0], y_train[0])

In [None]:
# Input canvas width and height in pixels.
# NOTE(review): presumably the size of the augmented images -- confirm against
# x_train's shape.
W, H = 112, 112

In [None]:
class DenseCapModel(nn.Module):
    """DenseCap-style dense captioning network (work in progress).

    Follows the layout of Johnson et al., "DenseCap: Fully Convolutional
    Localization Networks for Dense Captioning" (CVPR 2016):

    1. ``feature_extractor`` -- VGG-style conv stack with three 2x2 max-pools,
       so its output is ``(N, 512, H // 8, W // 8)``.
    2. ``rpn`` -- region proposal head emitting ``5 * k`` channels per cell
       (presumably 1 score + 4 box values per anchor).
    3. ``recognition`` -- two-layer MLP over a flattened ``512 x X x Y`` crop.
    4. ``second_rpn`` -- per-region head: 1 confidence + 4 box offsets.
    5. ``region_to_rnn`` / ``rnn`` / ``word_head`` -- LSTM decoder producing
       vocabulary logits.

    Cropping regions out of the feature map (bilinear sampling) is not
    implemented yet; see the TODOs in :meth:`forward`.

    Args:
        in_channels: Channels of the input images. Defaults to 1 on the
            assumption that the augmented MNIST canvases are grayscale --
            TODO confirm against the output of ``MNISTAug.get_augmented``.
        vocab_size: Caption vocabulary size ``V``. TODO: derive from the
            captions instead of the placeholder 100.
        region_size: Side length ``X == Y`` of the fixed-size crop each region
            is resampled to. DenseCap uses 7; a 64x64 crop would make the first
            recognition layer ~8.6e9 parameters (~34 GB in float32).
        hidden_size: Width of the recognition MLP (the "region code" size).
        rnn_size: Input and hidden size of the LSTM decoder.
        num_anchors: Number of anchors ``k`` per feature-map cell.
    """

    def __init__(self, in_channels=1, vocab_size=100, region_size=7,
                 hidden_size=4096, rnn_size=512, num_anchors=9):
        super().__init__()

        self.k = num_anchors
        self.X = region_size
        self.Y = region_size
        self.V = vocab_size  # TODO: set from the caption vocabulary

        self.feature_extractor = nn.Sequential(
            # The first conv maps the image channels to 64 so that the next
            # conv (64 -> 64) receives the channel count it expects.
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
        )

        self.rpn = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            # Input width must match the 256 channels of the conv above.
            nn.Conv2d(256, 5 * self.k, kernel_size=1, stride=1, padding=0),
            # TODO: sigmoid / tanh on the k scores; box values need their own
            # activation (a ReLU would forbid negative offsets).
        )

        self.recognition = nn.Sequential(
            nn.Linear(self.X * self.Y * 512, hidden_size),
            nn.ReLU(),

            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
        )

        # One confidence score + 4 box offsets per region.
        self.second_rpn = nn.Linear(hidden_size, 5)

        # Projects each region code to the LSTM input width; the projected code
        # is fed to the decoder as its first time step.
        self.region_to_rnn = nn.Linear(hidden_size, rnn_size)
        # nn.LSTM returns (output, (h, c)), so it cannot sit in an nn.Sequential
        # in front of nn.Linear; the decoder and its output head are separate.
        self.rnn = nn.LSTM(rnn_size, rnn_size)
        # Emits raw logits: pair with nn.CrossEntropyLoss for training, or apply
        # softmax(dim=-1) at inference time.
        self.word_head = nn.Linear(rnn_size, self.V)

    def forward(self, x):
        """Run the convolutional backbone and the region proposal head.

        Args:
            x: Image batch of shape ``(N, in_channels, H, W)``.

        Returns:
            ``(feature_map, region_proposals)`` with shapes
            ``(N, 512, H // 8, W // 8)`` and ``(N, 5 * k, H // 8, W // 8)``.
            Both are returned so the intermediate outputs can be inspected until
            the region-cropping stage below is implemented.
        """
        feature_map = self.feature_extractor(x)

        region_proposals = self.rpn(feature_map)

        # TODO: Apply sigmoid / tanh on the scores and an offset-friendly
        #       activation on the box values
        # TODO: Project the regions to features
        # TODO: Slice the regions in the features
        # TODO: Apply bilinear interpolation on the slices (to X x Y crops)
        # TODO: For each batch of crops, call recognize_and_generate
        return feature_map, region_proposals

    def recognize_and_generate(self, feature_map):
        """Score, refine and start captioning a batch of cropped regions.

        Args:
            feature_map: Region crops of shape ``(R, 512, X, Y)``; a single
                unbatched crop of shape ``(512, X, Y)`` is also accepted.

        Returns:
            ``(box_out, word_logits)`` where ``box_out`` is ``(R, 5)`` -- raw
            confidence followed by 4 box offsets (activations still TODO) -- and
            ``word_logits`` is ``(R, V)`` -- logits for the first caption token.
        """
        if feature_map.dim() == 3:
            feature_map = feature_map.unsqueeze(0)
        # Keep the region dimension; flattening everything would merge all
        # regions into a single vector.
        features = torch.flatten(feature_map, start_dim=1)

        recognized = self.recognition(features)

        # TODO: Apply activations (e.g. sigmoid on column 0). A ReLU on the
        # offsets would clamp negative shifts to zero.
        box_out = self.second_rpn(recognized)

        # nn.LSTM (batch_first=False) expects (seq_len, batch, features): feed
        # each region code as a length-1 sequence.
        rnn_in = F.relu(self.region_to_rnn(recognized)).unsqueeze(0)
        rnn_out, _ = self.rnn(rnn_in)
        word_logits = self.word_head(rnn_out.squeeze(0))

        return box_out, word_logits
        
