In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

from lib import detection_utils as utils
from lib.mnist_aug.mnist_augmenter import DataManager, MNISTAug
import collections

In [2]:
k = 9
H = 112
W = 112
Wp = 22
Hp = 22
b_regions = 256

threshold_p = 0.6
threshold_n = 0.3

In [3]:
dm = DataManager()
dm.load()

In [4]:
aug = MNISTAug()
x_train, y_train = aug.get_augmented(dm.x_train, dm.y_train, 10)
x_test, y_test = aug.get_augmented(dm.x_test, dm.y_test, 2)

x_train = torch.tensor(x_train, dtype=torch.float32).view((-1, 1, H, W))
x_test = torch.tensor(x_test, dtype=torch.float32).view((-1, 1, H, W))


In [6]:
class MnistDetector(nn.Module):

    def __init__(self, k):
        super().__init__()

        self.threshold_p = 0.6
        self.threshold_n = 0.3

        self.Wp = 22
        self.Hp = 22

        self.X = 64  # Width of region
        self.Y = 64

        self.b_regions = 256

        self.k = k

        self.DetectorOut = collections.namedtuple('DetectorOut', 'features confidences diffs regions_p regions_n idx_p idx_n matched_bboxes')
        self.anchors_tensor = utils.generate_anchors(shape=(Wp, Hp), sizes=(.15, .45, .75),
                                        ratios=(0.5, 1, 2))  # Tensor of shape (4, k*H*W) -> cy, cy, w, h

        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 16, 3),
            nn.ReLU(),
            nn.Conv2d(16, 16, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(16, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2,2),

            nn.Conv2d(32, 64, 3),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            # nn.MaxPool2d(2,2),

            nn.Conv2d(64, 128, 3),
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(),
            # nn.MaxPool2d(2,2),
        )
        self.box_regressor = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),

            nn.Conv2d(256, 5 * self.k, 1)
        )



    def forward(self, x, y_bboxes=None):
        """
        Parameters
        ---------
        x: tensor of shape (-1, C, H, W)
        bboxes: (optional) list of tensors of shape (4, n)
        """
        features = self.feature_extractor(x)
        bboxes = self.box_regressor(features)
        bboxes = bboxes.view((-1, 5, k, *bboxes.shape[-2:]))

        regions_p = []
        regions_n = []
        idx_p_batch = []
        idx_n_batch = []
        best_bbox_idx_batch = []

        # If training mode, then sample positives and negatives, extract regions
        if self.training and y_bboxes is not None:
            for i_batch in range(len(x)):
                iou = utils.get_iou_map(y_bboxes[i], self.anchors_tensor)
                iou = utils.raise_bbox_iou(iou, self.threshold_p)
                iou_max, iou_argmax = torch.max(iou, 0)  # Shape (k*H*W)
                idx_p, idx_n = utils.sample_pn_indices(iou_max, self.threshold_p, self.threshold_n, self.b_regions)
                pred_bbox_p, pred_bbox_n = utils.get_pred_boxes(bboxes[i, 1:], self.anchors_tensor, idx_p, idx_n)  # (4, n) (cx, cy, w, h)
                pred_bbox_p = utils.centers_to_diag(pred_bbox_p)  # shape (4, p) (x1y1x2y2)
                pred_bbox_n = utils.centers_to_diag(pred_bbox_n)

                idx_p_batch.append(idx_p)
                idx_n_batch.append(idx_n)
                best_bbox_idx_batch.append(iou_argmax)

                # Make coordinates feature indices b/w H and W
                multiplier = torch.tensor([self.Wp, self.Hp, self.Wp, self.Hp]).view((4, 1))
                pred_bbox_p = (pred_bbox_p * multiplier).type(torch.int32)  # shape (4, p) (x1y1x2y2)
                pred_bbox_n = (pred_bbox_n * multiplier).type(torch.int32)

                # Make crops of features
                regions_batch = []
                for positive_idx in range(len(idx_p)):
                    idx = pred_bbox_p[positive_idx]
                    cropped = features[:, idx[0]:idx[2]+1, idx[1]:idx[3]+1]
                    cropped = F.interpolate(cropped, (self.H, self.W), mode='bilinear')
                    regions_batch.append(cropped)
                regions_p.append(torch.tensor(regions_batch))

                regions_batch = []
                for negative_idx in range(len(idx_n)):
                    idx = pred_bbox_n[negative_idx]
                    cropped = features[:, idx[0]:idx[2]+1, idx[1]:idx[3]+1]
                    cropped = F.interpolate(cropped, (self.H, self.W), mode='bilinear')
                    regions_batch.append(cropped)
                regions_n.append(torch.tensor(regions_batch))

        # TODO: If eval mode, then sample top 300 confidence anchors' regions
        if not self.training:
            pass

        return self.DetectorOut(
            features=features,
            confidences=bboxes[:, 0] if bboxes is not None else None,
            diffs=bboxes[:, 1:] if bboxes is not None else None,
            regions_p=regions_p,
            regions_n=regions_n,
            idx_p=idx_p_batch,
            idx_n=idx_n_batch,
            matched_bboxes=best_bbox_idx_batch)

In [7]:
model = MnistDetector(k)

In [8]:
optimizer = Adam(model.parameters())

In [12]:
def process_y_batch(y):
    ys = []
    for yi in y:
        yi_ = utils.labels_to_tensor(yi, H, W)
        iou = utils.get_iou_map(yi_, anchors_tensor)
        iou = utils.raise_bbox_iou(iou, threshold_p)
        iou_max, iou_argmax = torch.max(iou, 0)  # Shape (k*H*W)

        confidences = utils.get_confidences(iou_max, threshold_p, (k, Hp, Wp))
        diffs = utils.get_diffs(yi_, anchors_tensor, iou_max, iou_argmax, k, Hp, Wp)

        ys.append(torch.stack((confidences, *diffs)))

    return torch.stack(ys)


epochs = 1
batch_size = 2
for epoch in range(epochs):

    for i in range(0, len(x_train), batch_size):
        start_index = i
        end_index = i+batch_size

        x_batch = x_train[start_index:end_index]
        y_batch = process_y_batch(y_train[start_index:end_index])

        y_hat = model(x_batch)

        indices = []
        for j in range(batch_size):
            idx_p, idx_n = utils.sample_pn_indices(y_hat[i][0].flatten(0), threshold_p, threshold_n, b_regions)
            print(idx_p.shape)
            break
        break


torch.Size([0])


In [10]:
# ==================

i = 0
start_index = i
end_index = i + batch_size

x_batch = x_train[start_index:end_index]
y_batch = y_train[start_index:end_index]

# Preprocessing y_batch -----------
confidences_batch = []
diffs_batch = []

# For each item in the batch:
yi = y_batch[0]

yi_ = utils.labels_to_tensor(yi, H, W)
iou = utils.get_iou_map(yi_, anchors_tensor)
iou = utils.raise_bbox_iou(iou, threshold_p)
iou_max, iou_argmax = torch.max(iou, 0)  # Shape (k*H*W)

confidences = utils.get_confidences(iou_max, threshold_p, (k, Hp, Wp))
diffs = utils.get_diffs(yi_, anchors_tensor, iou_max, iou_argmax, k, Hp, Wp)

confidences_batch.append(confidences)
diffs_batch.append(diffs)

confidences_batch = torch.stack(confidences_batch)
diffs_batch = torch.stack(diffs_batch)

# return confidences_batch, diffs_batch  # -------- END PROCESSING ----------

y_hat = model(x_batch)

indices = []
for j in range(batch_size):
    idx_p, idx_n = utils.sample_pn_indices(y_hat[i][0].flatten(0), threshold_p, threshold_n, b_regions)
    print(idx_p.shape)
    break

In [11]:
model.training


True