In [16]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam
from torchvision import ops

from lib import detection_utils as utils
from lib.mnist_aug.mnist_augmenter import DataManager, MNISTAug
import collections

In [17]:
# Number of anchors per feature-map cell (the regression head emits 5 * k channels).
k = 9

# Input canvas size in pixels.
H, W = 112, 112

# Spatial size of the feature map the anchors are laid out on.
Wp, Hp = 22, 22

# Passed to the positive/negative sampler as its region budget.
b_regions = 256

# IoU thresholds used to label anchors as positive / negative.
threshold_p, threshold_n = 0.6, 0.3

In [18]:
# Load the dataset through the project's DataManager; the next cell reads
# dm.x_train / dm.y_train / dm.x_test / dm.y_test from it.
dm = DataManager()
dm.load()

In [19]:
# Build augmented train / test sets from the loaded arrays.
# NOTE(review): the third argument (10 / 2) is forwarded to get_augmented unchanged; its
# meaning is not visible in this notebook — presumably a sample count or multiplier, confirm
# in lib.mnist_aug.
aug = MNISTAug()
x_train, y_train = aug.get_augmented(dm.x_train, dm.y_train, 10)
x_test, y_test = aug.get_augmented(dm.x_test, dm.y_test, 2)

# Convert the images to float32 tensors laid out as (N, 1, H, W): one channel per H x W canvas.
x_train = torch.tensor(x_train, dtype=torch.float32).view((-1, 1, H, W))
x_test = torch.tensor(x_test, dtype=torch.float32).view((-1, 1, H, W))


In [20]:
class MnistDetector(nn.Module):
    """Anchor-based digit detector.

    The network consists of:

    * ``feature_extractor`` - a conv stack that maps a (B, 1, 112, 112) input to a
      (B, 128, 22, 22) feature map (two 2x2 poolings, four un-padded 3x3 convs),
    * ``box_regressor`` - a head emitting ``5 * k`` channels per cell: one confidence
      logit and four box offsets for each of the ``k`` anchors,
    * ``classifier`` - a head for fixed-size (128, 28, 28) feature crops. It is defined
      here but not invoked by ``forward``.

    Parameters
    ----------
    k: number of anchors per feature-map cell.
       NOTE(review): presumably must equal len(sizes) * len(ratios) of the anchor
       generator call below (3 * 3 = 9) - confirm against ``utils.generate_anchors``.
    """

    def __init__(self, k):
        super().__init__()

        # IoU thresholds handed to the utils helpers that label / sample anchors.
        self.threshold_p = 0.6
        self.threshold_n = 0.3

        # Spatial size of the feature map (see class docstring for the derivation).
        self.Wp = 22
        self.Hp = 22

        self.X = 28  # Width of region
        self.Y = 28  # Second spatial size every cropped region is resized to

        # Passed to utils.sample_pn_indices as the sampling budget.
        self.b_regions = 256

        self.k = k

        self.DetectorOut = collections.namedtuple('DetectorOut', ['features', 'confidences', 'diffs', 'regions_p', 'regions_n', 'pred_bbox_p', 'pred_bbox_n', 'idx_p', 'idx_n', 'matched_bboxes', 'iou_max'])

        # Use the instance's own feature-map size rather than notebook-level globals,
        # so the class works regardless of which cells have been executed.
        self.anchors_tensor = utils.generate_anchors(shape=(self.Wp, self.Hp), sizes=(.15, .45, .75),
                                        ratios=(0.5, 1, 2))  # Tensor of shape (4, k*H*W) -> cx, cy, w, h

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            # No pooling in the last two stages.
            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),

            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
        )
        self.box_regressor = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),

            # 5 values per anchor: 1 confidence logit + 4 box offsets.
            nn.Conv2d(256, 5 * self.k, 1)
        )
        self.classifier = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Flatten(),
            # 28 -> 14 -> 7 -> 3 after three poolings, so 128 * 3 * 3 = 1152 inputs.
            nn.Linear(1152, 512),
            nn.ReLU(),

            nn.Linear(512, 10),
            # Explicit class axis; the implicit-dim form is deprecated and warns.
            nn.Softmax(dim=1)
        )

    def _crop_regions(self, feature_map, boxes):
        """Crop each box from one image's feature map and resize it to (self.X, self.Y).

        Parameters
        ---------
        feature_map: tensor of shape (C, Hf, Wf) - features of a single image
        boxes: integer tensor of shape (4, n) holding inclusive feature-cell corners

        Returns
        -------
        Tensor of shape (n, C, self.X, self.Y). When ``n == 0`` an empty tensor of shape
        (0, C, self.X, self.Y) is returned instead of letting ``torch.stack`` fail on an
        empty list.
        """
        crops = []
        for box in boxes.T:
            # NOTE(review): box[0] / box[2] (labelled x1 / x2 by the caller) index the first
            # spatial axis here. That is only interchangeable with the second axis because
            # Hp == Wp - confirm the axis convention used by the utils helpers.
            crop = feature_map[:, box[0]:box[2] + 1, box[1]:box[3] + 1]
            crop = F.interpolate(crop.unsqueeze(0), (self.X, self.Y), mode='bilinear', align_corners=False)[0]
            crops.append(crop)

        if not crops:
            return feature_map.new_zeros((0, feature_map.shape[0], self.X, self.Y))
        return torch.stack(crops)

    def forward(self, x, y_bboxes=None):
        """
        Run the backbone and the box head; in training mode also sample anchors and crop regions.

        Parameters
        ---------
        x: tensor of shape (-1, C, H, W)
        y_bboxes: (optional) list of tensors of shape (4, n)

        Returns
        -------
        DetectorOut namedtuple with
            features:    backbone output, (B, 128, Hf, Wf)
            confidences: sigmoid of the confidence logits, (B, k, Hf, Wf)
            diffs:       raw box offsets, (B, 4, k, Hf, Wf)
            regions_p / regions_n, pred_bbox_p / pred_bbox_n, idx_p / idx_n,
            matched_bboxes, iou_max: per-image lists, filled only when the module is in
            training mode and ``y_bboxes`` is given; empty lists otherwise.
        """
        features = self.feature_extractor(x)
        bboxes = self.box_regressor(features)
        # Split the 5 * k channels into (5, k); use self.k, not a notebook-level global.
        bboxes = bboxes.view((-1, 5, self.k, *bboxes.shape[-2:]))
        confidences = torch.sigmoid(bboxes[:, 0])

        regions_p = []
        regions_n = []
        pred_bbox_p_batch = []
        pred_bbox_n_batch = []
        idx_p_batch = []
        idx_n_batch = []
        best_bbox_idx_batch = []
        iou_max_batch = []

        # If training mode, then sample positives and negatives, extract regions
        if self.training and y_bboxes is not None:
            # Loop-invariant scale from normalised coordinates to feature-cell indices.
            multiplier = torch.tensor([self.Wp, self.Hp, self.Wp, self.Hp]).view((4, 1))

            for i_batch in range(len(x)):
                # Index by the loop variable so each image is matched with its own boxes.
                iou = utils.get_iou_map(y_bboxes[i_batch], self.anchors_tensor)
                iou = utils.raise_bbox_iou(iou, self.threshold_p)
                iou_max, iou_argmax = torch.max(iou, 0)  # Shape (k*H*W)

                # Random sampling
                idx_p, idx_n = utils.sample_pn_indices(iou_max, self.threshold_p, self.threshold_n, self.b_regions)

                # Get off-set boxes
                pred_bbox_p, pred_bbox_n = utils.get_pred_boxes(bboxes[i_batch, 1:], self.anchors_tensor, idx_p, idx_n)  # (4, n) (cx, cy, w, h)

                # Remove tiny boxes
                # NOTE(review): the helper is named get_tiny_box_indices but its result is used
                # as the set of boxes to KEEP - confirm it returns the non-tiny indices.
                big_box_indices_p = utils.get_tiny_box_indices(pred_bbox_p, 0.05)
                big_box_indices_n = utils.get_tiny_box_indices(pred_bbox_n, 0.05)
                pred_bbox_p = pred_bbox_p[:, big_box_indices_p]
                pred_bbox_n = pred_bbox_n[:, big_box_indices_n]
                idx_p = idx_p[big_box_indices_p]
                idx_n = idx_n[big_box_indices_n]

                # Change format from (cx cy w h) to (x1 y1 x2 y2)
                pred_bbox_p = utils.centers_to_diag(pred_bbox_p)  # shape (4, p) (x1y1x2y2)
                pred_bbox_n = utils.centers_to_diag(pred_bbox_n)

                # Make record of these (normalised, un-clipped boxes)
                idx_p_batch.append(idx_p)
                idx_n_batch.append(idx_n)
                best_bbox_idx_batch.append(iou_argmax)
                iou_max_batch.append(iou_max)
                pred_bbox_p_batch.append(pred_bbox_p)
                pred_bbox_n_batch.append(pred_bbox_n)

                # De-Normalize - Make coordinates feature indices b/w H and W
                cells_p = (pred_bbox_p * multiplier).round().type(torch.int32)  # shape (4, p) (x1y1x2y2)
                cells_n = (pred_bbox_n * multiplier).round().type(torch.int32)

                # Clip boxes that are out of range. The upper bound is the last valid cell
                # index (size - 1): clipping to the size itself would allow a crop that starts
                # one past the end, i.e. an empty slice that F.interpolate cannot resize.
                cells_p = ops.clip_boxes_to_image(cells_p.T, (self.Hp - 1, self.Wp - 1)).T
                cells_n = ops.clip_boxes_to_image(cells_n.T, (self.Hp - 1, self.Wp - 1)).T

                # Make crops of features
                regions_p.append(self._crop_regions(features[i_batch], cells_p))
                regions_n.append(self._crop_regions(features[i_batch], cells_n))

        # TODO: If eval mode, then sample top 300 confidence anchors' regions

        return self.DetectorOut(
            features=features,
            confidences=confidences,
            diffs=bboxes[:, 1:],
            regions_p=regions_p,
            regions_n=regions_n,
            pred_bbox_p=pred_bbox_p_batch,
            pred_bbox_n=pred_bbox_n_batch,
            idx_p=idx_p_batch,
            idx_n=idx_n_batch,
            matched_bboxes=best_bbox_idx_batch,
            iou_max=iou_max_batch,
        )

# Instantiate the detector with k anchors per feature-map cell.
model = MnistDetector(k)

In [21]:
# Adam with its default hyper-parameters over every parameter of the detector.
optimizer = Adam(model.parameters())

In [22]:
# ==================
# Single-batch run: take the first `batch_size` samples and push them through the model.

batch_size = 2

# NOTE(review): `i` is a notebook-level global; any other cell that reads `i` picks up this value.
i = 0
start_index = i
end_index = i + batch_size

x_batch = x_train[start_index:end_index]  # TODO: maybe add light noise?
y_batch = y_train[start_index:end_index]

# Convert each sample's labels into a box tensor (forward() documents these as shape (4, n)).
y_boxes = [utils.labels_to_tensor(yi, H, W) for yi in y_batch]

# A freshly constructed nn.Module is in training mode, so passing boxes takes the
# sampling / cropping branch of forward().
detector_out = model(x_batch, y_boxes)



In [24]:
# Target confidences for every anchor, built from the per-image max-IoU vectors and the
# positive threshold.
# Shape: (batch, k, H, W) | ones and zeros tensor. (H, W here are the feature-map sizes Hp, Wp.)
confidences_labels = utils.get_confidences(
    torch.stack(detector_out.iou_max),
    threshold_p,
    (batch_size, model.k, model.Hp, model.Wp)
)
confidences_labels.shape

torch.Size([2, 9, 22, 22])

In [26]:
# Regression targets for every anchor: one (4, k, H, W) tensor per image, stacked over the batch.
diffs_labels = torch.stack([
    utils.get_diffs(
        # Ground-truth boxes of the image being processed. Indexing with the notebook-level
        # `i` would hand the same image's boxes to every element of the batch.
        y_boxes[i_batch],
        model.anchors_tensor,
        detector_out.iou_max[i_batch],
        detector_out.matched_bboxes[i_batch],
        model.k,
        model.Hp,
        model.Wp
    )  # Shape: (4, k, H, W)
    for i_batch in range(batch_size)
])
diffs_labels.shape

TypeError: list indices must be integers or slices, not tuple

In [28]:
# Inspect, for the first image, the index of the best-matching ground-truth box per anchor
# (the argmax over boxes computed in forward()).
detector_out.matched_bboxes[0]

tensor([0, 0, 0,  ..., 2, 2, 1])

In [None]:
# BCELoss takes probabilities; the model already applies a sigmoid to its confidence logits.
confidences_loss_fn = nn.BCELoss()
# L1 distance between predicted and target box offsets.
diffs_loss_fn = nn.L1Loss()

In [None]:
# Objectness loss over every anchor of the batch; both tensors are (batch, k, Hp, Wp).
confidences_loss = confidences_loss_fn(detector_out.confidences, confidences_labels)
confidences_loss

In [None]:
# Box-offset loss. L1Loss with default reduction averages over all elements, i.e. over
# every anchor.
# NOTE(review): whether non-positive anchors are neutralised depends on what
# utils.get_diffs returns for them — confirm.
diffs_loss = diffs_loss_fn(detector_out.diffs, diffs_labels)
diffs_loss

In [None]:
# Unweighted sum of the two loss terms.
total_loss = confidences_loss + diffs_loss

# One optimisation step: clear stale gradients, back-propagate, update the parameters.
optimizer.zero_grad()
total_loss.backward()
optimizer.step()

In [None]:
# Non-maximum suppression over the sampled boxes of each image.
nms_boxes = []
for i_batch in range(batch_size):
    # Concatenate negatives then positives along the box axis; result has shape (4, n + p).
    pred_boxes = torch.cat((detector_out.pred_bbox_n[i_batch].T, detector_out.pred_bbox_p[i_batch].T)).T

    # Look up the confidence of each sampled anchor, in the same negatives-then-positives order.
    confidences_batch = detector_out.confidences[i_batch].flatten()
    confidences_batch_p = confidences_batch[detector_out.idx_p[i_batch]]
    confidences_batch_n = confidences_batch[detector_out.idx_n[i_batch]]
    confidences_batch = torch.cat((confidences_batch_n, confidences_batch_p))

    # ops.nms takes boxes as (N, 4) in (x1, y1, x2, y2) and returns the kept indices;
    # 0.7 is the IoU threshold above which a lower-scored box is suppressed.
    nms_indices = ops.nms(pred_boxes.T, confidences_batch, 0.7)
    nms_boxes_i = pred_boxes[:, nms_indices]

    print(nms_boxes_i.shape)
    # NOTE(review): presumably converts the (4, m) tensor back to the label format used by
    # DataManager.plot_num — confirm in lib.detection_utils.
    nms_boxes.append(utils.tensor_to_labels(nms_boxes_i, H, W))

In [None]:
# Plot every image of the batch together with its own NMS-filtered boxes.
for i_batch in range(batch_size):
    # Index the image with the loop variable; the notebook-level `i` would draw the same
    # image for every element of the batch.
    DataManager.plot_num(x_batch[i_batch].view((H, W)), nms_boxes[i_batch])
