In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from lib import detection_utils as utils
from lib.mnist_aug.mnist_augmenter import DataManager, MNISTAug

In [2]:
k = 9
H = 112
W = 112
Wp = 22
Hp = 22
b_regions = 256

threshold_p = 0.6
threshold_n = 0.3

In [3]:
dm = DataManager()
dm.load()

In [4]:
aug = MNISTAug()
x_train, y_train = aug.get_augmented(dm.x_train, dm.y_train, 10)
x_test, y_test = aug.get_augmented(dm.x_test, dm.y_test, 2)

x_train = torch.tensor(x_train, dtype=torch.float32).view((-1, 1, H, W))
x_test = torch.tensor(x_test, dtype=torch.float32).view((-1, 1, H, W))

In [5]:
anchors_tensor = utils.generate_anchors(shape=(Wp, Hp), sizes=(.15, .45, .75),
                                        ratios=(0.5, 1, 2))  # Tensor of shape (4, k*H*W) -> cy, cy, w, h


In [6]:
class MnistDetector(nn.Module):

    def __init__(self, k):
        super().__init__()
        self.k = k
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            # nn.MaxPool2d(2,2),

            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            # nn.MaxPool2d(2,2),

            nn.Conv2d(128, 5 * self.k, 1)
        )


    def forward(self, x):
        features = self.feature_extractor(x)
        features = features.view((-1, 5, k, *features.shape[-2:]))
        return features

In [7]:
model = MnistDetector(k)

In [8]:
optimizer = Adam(model.parameters())

In [9]:
def process_y_batch(y):
    ys = []
    for yi in y:
        yi_ = utils.labels_to_tensor(yi, H, W)
        iou = utils.get_iou_map(yi_, anchors_tensor)
        iou = utils.raise_bbox_iou(iou, threshold_p)
        iou_max, iou_argmax = torch.max(iou, 0)  # Shape (k*H*W)

        confidences = utils.get_confidences(iou_max, threshold_p, (k, Hp, Wp))
        diffs = utils.get_diffs(yi_, anchors_tensor, iou_max, iou_argmax, k, Hp, Wp)

        ys.append(torch.stack((confidences, *diffs)))

    return torch.stack(ys)

In [10]:
epochs = 1
batch_size = 2
for epoch in range(epochs):

    for i in range(0, len(x_train), batch_size):
        start_index = i
        end_index = i+batch_size

        x_batch = x_train[start_index:end_index]
        y_batch = process_y_batch(y_train[start_index:end_index])

        y_hat = model(x_batch)

        indices = []
        for j in range(batch_size):
            idx_p, idx_n = utils.sample_pn_indices(y_hat[i][0].flatten(0), threshold_p, threshold_n, b_regions)
            print(idx_p.shape)
            break
        break


torch.Size([0])
