In [3]:
import torch
import torch.nn as nn
import torchvision


In [2]:
def apply_regressions_pred_to_anchors_or_proposals(box_trans, anchor_or_prop):
    """Decode predicted box deltas into corner-format boxes.

    This is the inverse of the ``(dx, dy, dw, dh)`` encoding in which the
    centre offset is expressed in units of the reference box size and the
    size change is expressed as the log of a ratio.

    :param box_trans: tensor of deltas; it is reshaped to
        ``(box_trans.size(0), -1, 4)`` with the last axis holding
        ``(dx, dy, dw, dh)``.
    :param anchor_or_prop: reference boxes indexed as ``[:, 0..3]`` =
        ``(x1, y1, x2, y2)``; its first dimension must match
        ``box_trans.size(0)``.
    :return: tensor of shape ``(box_trans.size(0), K, 4)`` holding
        ``(x1, y1, x2, y2)`` for every set of deltas.
    """
    box_trans = box_trans.reshape(box_trans.size(0), -1, 4)

    # Reference boxes in centre/size form.
    w = anchor_or_prop[:, 2] - anchor_or_prop[:, 0]
    h = anchor_or_prop[:, 3] - anchor_or_prop[:, 1]
    center_x = anchor_or_prop[:, 0] + 0.5 * w
    center_y = anchor_or_prop[:, 1] + 0.5 * h

    dx = box_trans[..., 0]
    dy = box_trans[..., 1]
    dw = box_trans[..., 2]
    dh = box_trans[..., 3]

    pred_center_x = dx * w[:, None] + center_x[:, None]
    pred_center_y = dy * h[:, None] + center_y[:, None]
    # The size deltas are log-ratios, so the reference size is *scaled* by
    # exp(delta); zero deltas must give back the reference box unchanged.
    pred_w = torch.exp(dw) * w[:, None]
    pred_h = torch.exp(dh) * h[:, None]

    pred_box_x1 = pred_center_x - 0.5 * pred_w
    pred_box_y1 = pred_center_y - 0.5 * pred_h
    pred_box_x2 = pred_center_x + 0.5 * pred_w
    pred_box_y2 = pred_center_y + 0.5 * pred_h

    return torch.stack(
        (pred_box_x1, pred_box_y1, pred_box_x2, pred_box_y2),
        dim=2,
    )


def get_iou(box1, box2):
    """Compute the pairwise intersection-over-union of two box sets.

    :param box1: tensor indexed as ``[:, 0..3]`` = ``(x1, y1, x2, y2)``,
        N boxes.
    :param box2: tensor with the same layout, M boxes.
    :return: ``(N, M)`` tensor where entry ``[i, j]`` is the IoU of
        ``box1[i]`` and ``box2[j]``.
    """
    area1 = (box1[:, 2] - box1[:, 0]) * (box1[:, 3] - box1[:, 1])
    area2 = (box2[:, 2] - box2[:, 0]) * (box2[:, 3] - box2[:, 1])

    # box1 gets a trailing singleton axis (N, 1) and box2 stays (M,), so the
    # element-wise max/min broadcast to the full (N, M) pair grid.
    x_left = torch.max(box1[:, None, 0], box2[:, 0])
    y_top = torch.max(box1[:, None, 1], box2[:, 1])
    x_right = torch.min(box1[:, None, 2], box2[:, 2])
    y_bottom = torch.min(box1[:, None, 3], box2[:, 3])

    # Non-overlapping pairs produce negative extents; clamp them to zero.
    intersection_area = (
        (x_right - x_left).clamp(min=0) * (y_bottom - y_top).clamp(min=0)
    )
    union_area = area1[:, None] + area2 - intersection_area

    return intersection_area / union_area


def clap_box_to_img_size(boxes, image_shape):
    """Clip box coordinates so they lie inside the image.

    :param boxes: tensor whose last axis is indexed as
        ``(x1, y1, x2, y2)`` at positions 0..3.
    :param image_shape: shape sequence; its last two entries are read as
        ``(height, width)``.
    :return: tensor with the leading dimensions of ``boxes`` and a last
        axis of 4, x values limited to ``[0, width]`` and y values to
        ``[0, height]``.
    """
    x1, y1, x2, y2 = (boxes[..., col] for col in range(4))
    img_h, img_w = image_shape[-2:]

    clipped = (
        x1.clamp(0, img_w),
        y1.clamp(0, img_h),
        x2.clamp(0, img_w),
        y2.clamp(0, img_h),
    )
    # Stacking on a new trailing axis rebuilds the (..., 4) layout.
    return torch.stack(clipped, dim=-1)


def boxes_to_transformation_targets(ground_truth_boxes, anchor_for_propoosals):
    """Encode ground-truth boxes as regression targets relative to reference boxes.

    :param ground_truth_boxes: tensor indexed as ``[:, 0..3]`` =
        ``(x1, y1, x2, y2)``, one matched box per reference box.
    :param anchor_for_propoosals: reference boxes in the same layout.
    :return: tensor stacked on dim 1 as ``(dx, dy, dw, dh)``: the centre
        offsets divided by the reference size, and the log of the
        size ratios.
    """
    def _center_and_size(b):
        # Corner format -> (centre x, centre y, width, height).
        span_x = b[:, 2] - b[:, 0]
        span_y = b[:, 3] - b[:, 1]
        return b[:, 0] + 0.5*span_x, b[:, 1] + 0.5*span_y, span_x, span_y

    ref_cx, ref_cy, ref_w, ref_h = _center_and_size(anchor_for_propoosals)
    gt_cx, gt_cy, gt_w, gt_h = _center_and_size(ground_truth_boxes)

    deltas = (
        (gt_cx - ref_cx)/ref_w,
        (gt_cy - ref_cy)/ref_h,
        torch.log(gt_w/ref_w),
        torch.log(gt_h/ref_h),
    )
    return torch.stack(deltas, dim=1)


def transform_boxes_to_original_size(boxes, new_size, original_size):
    r"""
    Rescale boxes from the resized image back to the original image.

    :param boxes: tensor unbound along dim 1 into
        ``(xmin, ymin, xmax, ymax)``.
    :param new_size: ``(height, width)`` of the resized image the boxes
        currently refer to.
    :param original_size: ``(height, width)`` of the image before resizing.
    :return: tensor stacked on dim 1 as ``(xmin, ymin, xmax, ymax)`` in
        original-image coordinates.
    """
    def _scalar(value):
        # float32 scalar tensor living on the same device as the boxes.
        return torch.tensor(value, dtype=torch.float32, device=boxes.device)

    # One original/new ratio per axis, in (height, width) order.
    ratio_height, ratio_width = [
        _scalar(orig_dim) / _scalar(new_dim)
        for new_dim, orig_dim in zip(new_size, original_size)
    ]

    x_lo, y_lo, x_hi, y_hi = boxes.unbind(1)
    rescaled = (
        x_lo * ratio_width,
        y_lo * ratio_height,
        x_hi * ratio_width,
        y_hi * ratio_height,
    )
    return torch.stack(rescaled, dim=1)



def sample_positive_negative(labels, positive_count, total_count):
    """Randomly pick a bounded set of positive and negative samples.

    Entries with ``labels >= 1`` are positives and entries with
    ``labels == 0`` are negatives; any other value is never selected.

    :param labels: 1-D label tensor.
    :param positive_count: maximum number of positives to sample.
    :param total_count: target number of positives plus negatives; the
        negatives fill whatever the sampled positives leave free.
    :return: ``(sampled_neg_index, sampled_pos_index)`` - two boolean masks
        shaped like ``labels``. Note that the negative mask comes first.
    """
    positive = torch.where(labels >= 1)[0]
    negative = torch.where(labels == 0)[0]

    # These are plain Python ints, so the builtin min() is required;
    # torch.min() only accepts tensors.
    num_positive = min(positive.numel(), positive_count)
    # Guard against a negative remainder, which would otherwise turn the
    # slice below into "all but the last k".
    num_negative = min(negative.numel(), max(total_count - num_positive, 0))

    perm_pos_index = torch.randperm(
        positive.numel(), device=positive.device
    )[:num_positive]
    perm_neg_index = torch.randperm(
        negative.numel(), device=negative.device
    )[:num_negative]

    pos_index = positive[perm_pos_index]
    neg_index = negative[perm_neg_index]

    sampled_pos_index = torch.zeros_like(labels, dtype=torch.bool)
    sampled_neg_index = torch.zeros_like(labels, dtype=torch.bool)
    sampled_pos_index[pos_index] = True
    sampled_neg_index[neg_index] = True

    return sampled_neg_index, sampled_pos_index


class RegionProposalNetwork(nn.Module):
    """Region proposal network head.

    A 3x3 convolution followed by two 1x1 convolutions that predict, for
    every anchor at every feature-map location, one objectness logit and
    four box regression deltas.

    NOTE(review): the anchor/prediction bookkeeping in ``forward`` only
    lines up for a batch size of 1 (anchors are generated once and
    ``target['bboxes'][0]`` is used) - confirm against the data loader.
    """

    def __init__(self, in_channels = 512) -> None:
        super(RegionProposalNetwork, self).__init__()
        self.scales = [128, 256, 512]
        self.aspect_ratio = [0.5, 1, 2]
        # One anchor per (scale, aspect ratio) pair at every location.
        self.num_anchors = len(self.scales)*len(self.aspect_ratio)

        self.rpn_conv = nn.Conv2d(in_channels,
                                  in_channels,
                                  kernel_size=3,
                                  stride=1,
                                  padding=1)
        self.clf_layer = nn.Conv2d(in_channels,
                                   self.num_anchors,
                                   kernel_size=1,
                                   stride=1)
        self.bbox_reg = nn.Conv2d(in_channels,
                                  self.num_anchors*4,
                                  kernel_size=1,
                                  stride=1)

    def generate_anchors(self, image, feat):
        """Build all anchors for one feature map.

        :param image: image tensor; only its last two dims (H, W) are read.
        :param feat: feature map; its last two dims give the grid size and
            its dtype/device are used for the anchors.
        :return: ``(grid_h * grid_w * num_anchors, 4)`` tensor of
            ``(x1, y1, x2, y2)`` anchors, ordered row-major over grid cells
            with all anchors of one cell adjacent.
        """
        grid_h, grid_w = feat.shape[-2:]
        image_h, image_w = image.shape[-2:]
        stride_h = torch.tensor(image_h//grid_h,
                                dtype=torch.int64,
                                device=feat.device)
        stride_w = torch.tensor(image_w//grid_w,
                                dtype=torch.int64,
                                device=feat.device)
        scales = torch.tensor(self.scales,
                              dtype=feat.dtype,
                              device=feat.device)
        aspect_ratios = torch.tensor(self.aspect_ratio,
                                     dtype=feat.dtype,
                                     device=feat.device)

        # h/w = aspect ratio while h*w stays scale**2.
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1/h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)
        # Zero-centred base anchors, one per (ratio, scale) pair.
        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1)/2
        base_anchors = base_anchors.round()

        # Image-space offset of every grid cell.
        shift_x = torch.arange(0, grid_w,
                               dtype=torch.int64,
                               device=feat.device) * stride_w
        shift_y = torch.arange(0, grid_h,
                               dtype=torch.int64,
                               device=feat.device) * stride_h
        shift_y, shift_x = torch.meshgrid(shift_y, shift_x,
                                          indexing='ij')

        shift_x = shift_x.reshape(-1)
        shift_y = shift_y.reshape(-1)
        shifts = torch.stack((shift_x,
                              shift_y,
                              shift_x,
                              shift_y), dim=1)
        # (cells, 1, 4) + (1, anchors, 4) -> every base anchor at every cell.
        anchors = (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4))
        anchors = anchors.reshape(-1, 4)
        return anchors

    @staticmethod
    def apply_regressions_pred_to_anchors_or_proposals(box_trans, anchor_or_prop):
        """Decode ``(dx, dy, dw, dh)`` deltas into ``(x1, y1, x2, y2)`` boxes.

        Declared as a static method because it uses no instance state.

        :param box_trans: deltas, reshaped to ``(box_trans.size(0), -1, 4)``.
        :param anchor_or_prop: reference boxes indexed as ``[:, 0..3]`` =
            ``(x1, y1, x2, y2)``.
        :return: ``(box_trans.size(0), K, 4)`` tensor of decoded boxes.
        """
        box_trans = box_trans.reshape(
            box_trans.size(0), -1, 4
        )
        w = anchor_or_prop[:, 2] - anchor_or_prop[:, 0]
        h = anchor_or_prop[:, 3] - anchor_or_prop[:, 1]
        center_x = anchor_or_prop[:, 0] + 0.5*w
        center_y = anchor_or_prop[:, 1] + 0.5*h

        dx = box_trans[..., 0]
        dy = box_trans[..., 1]
        dw = box_trans[..., 2]
        dh = box_trans[..., 3]

        pred_center_x = dx*w[:, None] + center_x[:, None]
        pred_center_y = dy*h[:, None] + center_y[:, None]
        # Size deltas are log-ratios: scale the reference size by exp(delta).
        pred_w = torch.exp(dw) * w[:, None]
        pred_h = torch.exp(dh) * h[:, None]

        pred_box_x1 = pred_center_x - 0.5*pred_w
        pred_box_y1 = pred_center_y - 0.5*pred_h
        pred_box_x2 = pred_center_x + 0.5*pred_w
        pred_box_y2 = pred_center_y + 0.5*pred_h

        pred_boxes = torch.stack((
            pred_box_x1,
            pred_box_y1,
            pred_box_x2,
            pred_box_y2
        ), dim=2)

        return pred_boxes

    def assign_target_to_anchors(self, anchors, gt_boxes):
        """Label every anchor as foreground, background or ignored.

        :param anchors: ``(A, 4)`` anchor boxes.
        :param gt_boxes: ``(G, 4)`` ground-truth boxes.
        :return: ``(labels, matched_gt_boxes)`` where ``labels`` is a
            float32 tensor with 1.0 (foreground), 0.0 (background, best
            IoU < 0.3) or -1.0 (ignored, best IoU in [0.3, 0.7)), and
            ``matched_gt_boxes`` holds the best-overlapping ground-truth
            box for every anchor.
        """
        iou_matrix = get_iou(gt_boxes, anchors)
        best_match_iou, best_match_gt_index = iou_matrix.max(dim=0)

        # Keep the un-thresholded assignment so low-quality matches can be
        # restored below.
        best_match_gt_index_copy = best_match_gt_index.clone()

        below_low = best_match_iou < 0.3
        between = (best_match_iou >= 0.3) & (best_match_iou < 0.7)

        best_match_gt_index[below_low] = -1
        best_match_gt_index[between] = -2

        # Every ground-truth box keeps the anchor(s) that overlap it most,
        # even when that overlap is under the foreground threshold.
        best_anchor_for_gt, _ = iou_matrix.max(dim=1)
        gt_pred_pairs_with_highest_iou = torch.where(
            iou_matrix == best_anchor_for_gt[:, None]
        )
        pred_inds_to_update = gt_pred_pairs_with_highest_iou[1]
        best_match_gt_index[pred_inds_to_update] = (
            best_match_gt_index_copy[pred_inds_to_update]
        )

        # Negative markers are clamped to 0 only to keep the gather valid;
        # those rows are background/ignored and excluded via the labels.
        matched_gt_boxes = gt_boxes[best_match_gt_index.clamp(min=0)]
        labels = best_match_gt_index >= 0
        labels = labels.to(dtype=torch.float32)

        background_anchors = best_match_gt_index == -1
        labels[background_anchors] = 0.0

        ignored_anchors = best_match_gt_index == -2
        labels[ignored_anchors] = -1.0

        return labels, matched_gt_boxes

    def filter_propoosals(self, propoosels, clf_score, img_shape):
        """Reduce raw proposals via top-k, clipping and NMS.

        :param propoosels: ``(N, 4)`` decoded proposal boxes.
        :param clf_score: objectness logits, flattened to ``(N,)``.
        :param img_shape: image shape; the last two entries are (H, W).
        :return: ``(proposals, scores)`` sorted by descending score, at
            most 2000 entries.
        """
        clf_score = clf_score.reshape(-1)
        clf_score = torch.sigmoid(clf_score)
        # topk raises if k exceeds the number of elements (small feature maps).
        pre_nms_top_n = min(1000, clf_score.numel())
        _, top_in_index = clf_score.topk(pre_nms_top_n)
        clf_score = clf_score[top_in_index]
        propoosels = propoosels[top_in_index]

        propoosels = clap_box_to_img_size(propoosels, img_shape)

        keep_indices = torchvision.ops.nms(propoosels, clf_score, 0.7)

        post_nms_keep_indices = keep_indices[
            clf_score[keep_indices].sort(descending=True)[1]
        ]

        propoosels = propoosels[post_nms_keep_indices[:2000]]
        clf_score = clf_score[post_nms_keep_indices[:2000]]

        return propoosels, clf_score

    def forward(self, image, feat, target=None):
        """Run the RPN and, in training, compute its losses.

        :param image: input image tensor (only its shape is used).
        :param feat: backbone feature map ``(B, C, H, W)``.
        :param target: dict with key ``'bboxes'``; required for the losses,
            may be ``None`` at inference.
        :return: dict with ``'propoosels'`` and ``'scores'``; when the
            module is in training mode and ``target`` is given it also
            contains ``'cls_loss'`` and ``'localization_loss'``.
        """
        rpn_feat = nn.functional.relu(self.rpn_conv(feat))
        cls = self.clf_layer(rpn_feat)
        box_trans = self.bbox_reg(rpn_feat)
        anchors = self.generate_anchors(image, feat)

        number_of_anchors_per_location = cls.size(1)
        # (B, A, H, W) -> (B, H, W, A) -> (B*H*W*A, 1), matching anchor order.
        cls = cls.permute(0, 2, 3, 1)
        cls = cls.reshape(-1, 1)

        # (B, A*4, H, W) -> (B, A, 4, H, W); height precedes width in the
        # conv output, so the view must use shape[-2] then shape[-1].
        box_trans = box_trans.view(
            box_trans.size(0),
            number_of_anchors_per_location,
            4,
            rpn_feat.shape[-2],
            rpn_feat.shape[-1]
        )
        box_trans = box_trans.permute(0, 3, 4, 1, 2)
        box_trans = box_trans.reshape(-1, 4)

        # Proposals are detached: they feed the next stage as fixed inputs.
        propoosels = self.apply_regressions_pred_to_anchors_or_proposals(
            box_trans.detach().reshape(-1, 1, 4),
            anchors)
        propoosels = propoosels.reshape(propoosels.size(0), 4)
        propoosels, scores = self.filter_propoosals(propoosels,
                                                    cls.detach(),
                                                    image.shape)
        rpn_output = {
            'propoosels': propoosels,
            'scores': scores
        }

        if not self.training or target is None:
            return rpn_output

        labels_for_anchors, matched_gt_boxes_for_anchors = self.assign_target_to_anchors(
            anchors,
            target['bboxes'][0]
        )

        regression_targets = boxes_to_transformation_targets(
            matched_gt_boxes_for_anchors, anchors
        )
        sampled_neg_mask, sampled_pos_mask = sample_positive_negative(
            labels_for_anchors, positive_count=128, total_count=256
        )

        sampled_index = torch.where(sampled_pos_mask | sampled_neg_mask)[0]
        # Box regression is only supervised on sampled positives, but is
        # normalised by the total number of sampled anchors.
        localization_loss = (
            nn.functional.smooth_l1_loss(
                box_trans[sampled_pos_mask],
                regression_targets[sampled_pos_mask],
                beta=1/9,
                reduction='sum'
            )/(sampled_index.numel())
        )
        cls_loss = nn.functional.binary_cross_entropy_with_logits(
            cls[sampled_index].flatten(),
            labels_for_anchors[sampled_index].flatten()
        )
        rpn_output['cls_loss'] = cls_loss
        rpn_output['localization_loss'] = localization_loss

        return rpn_output





In [None]:
class ROIHead(nn.Module):
    """Second-stage detection head.

    Pools a fixed-size feature for every proposal, passes it through two
    fully connected layers and predicts per-class scores and per-class box
    regression deltas. Class index 0 is treated as background.
    """

    def __init__(self, num_classes=2, in_channels = 512):
        super(ROIHead, self).__init__()
        self.num_classes = num_classes
        self.pool_size = 7
        self.fc_inner_dim = 1024

        self.fc6 = nn.Linear(in_channels*self.pool_size*self.pool_size,
                             self.fc_inner_dim)
        self.fc7 = nn.Linear(self.fc_inner_dim,
                             self.fc_inner_dim)
        self.fccls = nn.Linear(self.fc_inner_dim,
                               self.num_classes)
        self.reg_layer = nn.Linear(self.fc_inner_dim,
                                   self.num_classes*4)

    def assign_target_to_propoosels(self, propoosels, gt_boxes, gt_labels):
        """Match each proposal to a ground-truth box and class label.

        :param propoosels: ``(P, 4)`` proposal boxes.
        :param gt_boxes: ``(G, 4)`` ground-truth boxes.
        :param gt_labels: ``(G,)`` ground-truth class labels.
        :return: ``(labels, matched_gt_boxes)``; ``labels`` is int64 with 0
            for proposals whose best IoU is below 0.5.
        """
        iou_metrix = get_iou(gt_boxes, propoosels)
        best_match_iou, best_match_gt_index = iou_metrix.max(dim=0)
        below_low = best_match_iou < 0.5

        best_match_gt_index[below_low] = -1
        # -1 is clamped to 0 only to keep the gather valid; those rows are
        # overwritten with the background label below.
        matched_gt_boxes_for_propoosels = gt_boxes[best_match_gt_index.clamp(min=0)]

        labels = gt_labels[best_match_gt_index.clamp(min=0)]
        labels = labels.to(dtype=torch.int64)

        background_propoosels = best_match_gt_index == -1
        labels[background_propoosels] = 0

        return labels, matched_gt_boxes_for_propoosels

    def filter_predictions(self, pred_box, pred_labels, pred_scores):
        """Apply score threshold, size filter and per-class NMS.

        :param pred_box: ``(N, 4)`` boxes.
        :param pred_labels: ``(N,)`` class ids.
        :param pred_scores: ``(N,)`` class scores.
        :return: ``(pred_box, pred_scores, pred_labels)`` - note that the
            scores come *before* the labels - sorted by descending score
            and limited to 100 detections.
        """
        keep = torch.where(pred_scores > 0.05)[0]
        pred_box, pred_scores, pred_labels = pred_box[keep], pred_scores[keep], pred_labels[keep]

        # Drop degenerate boxes.
        min_size = 1
        ws, hs = pred_box[:, 2] - pred_box[:, 0], pred_box[:, 3] - pred_box[:, 1]
        keep = (ws >= min_size) & (hs >= min_size)
        keep = torch.where(keep)[0]
        pred_box, pred_scores, pred_labels = pred_box[keep], pred_scores[keep], pred_labels[keep]

        # NMS is run independently for every class.
        keep_mask = torch.zeros_like(pred_scores, dtype=torch.bool)
        for class_id in torch.unique(pred_labels):
            curr_indices = torch.where(pred_labels == class_id)[0]
            curr_keep_indices = torchvision.ops.nms(
                pred_box[curr_indices],
                pred_scores[curr_indices],
                0.5
            )
            keep_mask[curr_indices[curr_keep_indices]] = True
        keep_indecies = torch.where(keep_mask)[0]
        post_nms_keep_indecies = keep_indecies[pred_scores[keep_indecies].sort(
            descending=True
        )[1]]
        keep = post_nms_keep_indecies[:100]
        pred_box, pred_scores, pred_labels = pred_box[keep], pred_scores[keep], pred_labels[keep]

        return pred_box, pred_scores, pred_labels

    def forward(self, feat, propoosals, img_shape, target=None):
        """Classify and refine proposals.

        :param feat: backbone feature map.
        :param propoosals: ``(P, 4)`` proposals in image coordinates.
        :param img_shape: image shape; the last two entries are (H, W).
        :param target: dict with keys ``'bboxes'`` and ``'labels'``;
            required to compute losses, may be omitted at inference.
        :return: in training mode with a target, a dict with
            ``'frcnn-classification-loss'`` and ``'frcnn-localiztion-loss'``;
            otherwise a dict with ``'bboxes'``, ``'scores'`` and ``'labels'``.
        """
        compute_loss = self.training and target is not None

        if compute_loss:
            gt_boxes = target['bboxes'][0]
            gt_labels = target['labels'][0]
            labels, matched_gt_boxes_for_propoosels = self.assign_target_to_propoosels(
                propoosals, gt_boxes, gt_labels
            )
            sampled_neg_indx_mask, sampled_pos_indx_mask = sample_positive_negative(
                labels, positive_count=32, total_count=128
            )
            sampled_indx = torch.where(sampled_pos_indx_mask | sampled_neg_indx_mask)[0]
            # Only the sampled proposals are pooled and scored in training.
            propoosals = propoosals[sampled_indx]
            labels = labels[sampled_indx]

            matched_gt_boxes_for_propoosels = matched_gt_boxes_for_propoosels[sampled_indx]
            regrassion_targets = boxes_to_transformation_targets(
                matched_gt_boxes_for_propoosels, propoosals
            )

        # Feature map to image scale (1/16).
        # NOTE(review): presumably matches the backbone stride - confirm.
        spatial_scale = 0.0625

        propoosal_roi_pool_feats = torchvision.ops.roi_pool(
            feat,
            [propoosals],
            output_size=self.pool_size,
            spatial_scale=spatial_scale
        )
        propoosal_roi_pool_feats = propoosal_roi_pool_feats.flatten(start_dim=1)
        box_fc6 = torch.nn.functional.relu(self.fc6(propoosal_roi_pool_feats))
        box_fc7 = torch.nn.functional.relu(self.fc7(box_fc6))
        cls_score = self.fccls(box_fc7)
        box_transform_pred = self.reg_layer(box_fc7)

        num_boxes, num_classes = cls_score.shape
        box_transform_pred = box_transform_pred.reshape(num_boxes, num_classes, 4)

        frcnn_output = {}
        if compute_loss:
            classification_loss = nn.functional.cross_entropy(
                cls_score,
                labels
            )

            # Regression is supervised only for foreground proposals, using
            # the deltas predicted for their ground-truth class.
            fg_propoosal_indx = torch.where(labels > 0)[0]
            fg_class_labels = labels[fg_propoosal_indx]
            localization_loss = torch.nn.functional.smooth_l1_loss(
                box_transform_pred[fg_propoosal_indx, fg_class_labels],
                regrassion_targets[fg_propoosal_indx],
                beta=1/9,
                reduction='sum'
            )
            localization_loss = localization_loss/labels.numel()
            frcnn_output['frcnn-classification-loss'] = classification_loss
            frcnn_output['frcnn-localiztion-loss'] = localization_loss
            return frcnn_output

        pred_boxes = apply_regressions_pred_to_anchors_or_proposals(
            box_transform_pred,
            propoosals
        )
        pred_scores = torch.nn.functional.softmax(cls_score, dim=-1)

        pred_boxes = clap_box_to_img_size(pred_boxes, img_shape)
        pred_labels = torch.arange(num_classes, device=cls_score.device)
        pred_labels = pred_labels.view(1, -1).expand_as(pred_scores)

        # Column 0 is the background class; drop it.
        pred_boxes = pred_boxes[:, 1:]
        pred_labels = pred_labels[:, 1:]
        pred_scores = pred_scores[:, 1:]

        pred_boxes = pred_boxes.reshape(-1, 4)
        pred_labels = pred_labels.reshape(-1)
        pred_scores = pred_scores.reshape(-1)

        # filter_predictions returns (boxes, scores, labels) in that order.
        pred_boxes, pred_scores, pred_labels = self.filter_predictions(
            pred_boxes,
            pred_labels,
            pred_scores
        )
        frcnn_output['bboxes'] = pred_boxes
        frcnn_output['scores'] = pred_scores
        frcnn_output['labels'] = pred_labels

        return frcnn_output




In [None]:
class FasterRCNN(nn.Module):
    """Two-stage detector: VGG16 backbone, region proposal network, ROI head."""

    def __init__(self, num_classes=2) -> None:
        super(FasterRCNN, self).__init__()
        # NOTE(review): newer torchvision releases prefer the ``weights=``
        # argument over ``pretrained=True`` - confirm the pinned version.
        vgg16 = torchvision.models.vgg16(pretrained=True)
        # Drop the last layer of the VGG feature extractor.
        self.backbone = vgg16.features[:-1]
        self.rpn = RegionProposalNetwork(in_channels=512)
        self.roihead = ROIHead(num_classes=num_classes,
                               in_channels=512)
        # Freeze the first backbone layers.
        for layer in self.backbone[:10]:
            for p in layer.parameters():
                p.requires_grad = False
        self.image_mean = [0.485, 0.456, 0.406]
        self.image_std = [0.229, 0.224, 0.225]
        self.min_size = 600
        self.max_size = 1000

    def normalize_resize_image_and_boxes(self, image, bboxes):
        """Normalise the image, resize it and rescale the boxes to match.

        The image is scaled so that its shorter side reaches ``min_size``
        unless that would push the longer side past ``max_size``.

        :param image: ``(B, C, H, W)`` image tensor.
        :param bboxes: tensor unbound along dim 2 into
            ``(xmin, ymin, xmax, ymax)``, or ``None``.
        :return: ``(image, bboxes)``; ``bboxes`` is ``None`` when no boxes
            were passed in.
        """
        mean = torch.as_tensor(self.image_mean,
                               dtype=image.dtype,
                               device=image.device)
        std = torch.as_tensor(self.image_std,
                              dtype=image.dtype,
                              device=image.device)
        image = (image - mean[:, None, None]) / std[:, None, None]

        h, w = image.shape[-2:]
        im_shape = torch.tensor(image.shape[-2:])
        min_size = torch.min(im_shape).to(dtype=torch.float32)
        max_size = torch.max(im_shape).to(dtype=torch.float32)
        scale = torch.min(
            float(self.min_size) / min_size,
            float(self.max_size) / max_size
        )

        scale_factor = scale.item()
        image = torch.nn.functional.interpolate(
            image,
            size=None,
            scale_factor=scale_factor,
            mode='bilinear',
            recompute_scale_factor=True,
            align_corners=False
        )

        if bboxes is not None:
            # Use the actual resized/original ratios per axis.
            ratios = [
                torch.tensor(s, dtype=torch.float32, device=bboxes.device)
                / torch.tensor(s_orig, dtype=torch.float32, device=bboxes.device)
                for s, s_orig in zip(image.shape[-2:], (h, w))
            ]
            ratio_height, ratio_width = ratios
            xmin, ymin, xmax, ymax = bboxes.unbind(2)
            # Coordinates scale multiplicatively with the image size.
            xmin = xmin * ratio_width
            ymin = ymin * ratio_height
            xmax = xmax * ratio_width
            ymax = ymax * ratio_height
            bboxes = torch.stack((
                xmin,
                ymin,
                xmax,
                ymax
            ), dim=2)

        # Always return a pair so callers can unpack it with or without boxes.
        return image, bboxes

    def forward(self, image, target=None):
        """Run the full detector.

        :param image: ``(B, C, H, W)`` image tensor.
        :param target: dict with ``'bboxes'`` (and whatever else the RPN and
            ROI head read); needed for training. Its ``'bboxes'`` entry is
            replaced in place by the resized boxes.
        :return: ``(rpn_output, frcnn_out)``. Outside training mode
            ``frcnn_out['boxes']`` holds the detections rescaled to the
            input image size, while ``frcnn_out['bboxes']`` stays in
            resized-image coordinates.
        """
        old_shape = image.shape[-2:]
        if self.training and target is not None:
            image, bboxes = self.normalize_resize_image_and_boxes(
                image,
                target['bboxes']
            )
            target['bboxes'] = bboxes
        else:
            image, _ = self.normalize_resize_image_and_boxes(
                image,
                None
            )
        feat = self.backbone(image)
        rpn_output = self.rpn(image, feat, target)
        # The RPN publishes its proposals under this (misspelled) key.
        propoosals = rpn_output['propoosels']
        frcnn_out = self.roihead(feat, propoosals, image.shape[-2:], target)

        if not self.training:
            frcnn_out['boxes'] = transform_boxes_to_original_size(
                frcnn_out['bboxes'],
                image.shape[-2:],
                old_shape
            )
        return rpn_output, frcnn_out
        
