In [71]:
import torch 
import torch.nn as nn 

import torch.nn.functional as F 
from scipy.optimize import linear_sum_assignment
from torchvision.ops.boxes import box_area
import numpy as np 

def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(-1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=-1)

# modified from torchvision to also return the union


def box_iou(boxes1, boxes2):
    area1 = box_area(boxes1)
    area2 = box_area(boxes2)

    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]

    union = area1[:, None] + area2 - inter

    iou = inter / union
    return iou, union


def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/

    The boxes should be in [x0, y0, x1, y1] format

    Returns a [N, M] pairwise matrix, where N = len(boxes1)
    and M = len(boxes2)
    """
    # degenerate boxes gives inf / nan results
    # so do an early check
    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
    iou, union = box_iou(boxes1, boxes2)

    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])

    wh = (rb - lt).clamp(min=0)  # [N,M,2]
    area = wh[:, :, 0] * wh[:, :, 1]

    return iou - (area - union) / area




def dice_coef(inputs, targets):
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1).unsqueeze(1)
    targets = targets.flatten(1).unsqueeze(0)
    numerator = 2 * (inputs * targets).sum(2)
    denominator = inputs.sum(-1) + targets.sum(-1)

    # NOTE coef doesn't be subtracted to 1 as it is not necessary for computing costs
    coef = (numerator + 1) / (denominator + 1)
    return coef


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes


def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    N, M = len(inputs), len(targets)
    inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1)
    targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1)

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    coef = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        coef = alpha_t * coef

    return coef.mean(2)


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes



class HungarianMatcherIFC(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_dice: float = 1,
        num_classes: int = 80,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the sigmoid_focal error of the masks in the matching cost
            cost_dice: This is the relative weight of the dice loss of the masks in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_dice = cost_dice
        assert cost_class != 0 or cost_dice != 0, "all costs cant be 0"

        self.num_classes = num_classes
        self.num_cum_classes = [0] + \
            np.cumsum(np.array(num_classes) + 1).tolist()
        self.n_future = 4
    @torch.no_grad()
    def forward(self, outputs, targets):
        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].softmax(-1)
        out_mask = outputs["pred_masks"]
        B, Q, T, s_h, s_w = out_mask.shape
        t_h, t_w = targets[0]["match_masks"].shape[-2:]

        if (s_h, s_w) != (t_h, t_w):
            out_mask = out_mask.reshape(B, Q*T, s_h, s_w)
            out_mask = torch.nn.F.interpolate(out_mask, size=(
                t_h, t_w), mode="bilinear", align_corners=False)
            out_mask = out_mask.view(B, Q, T, t_h, t_w)

        indices = []
        for b_i in range(B):
            b_tgt_ids = targets[b_i]["labels"]
            b_out_prob = out_prob[b_i]

            cost_class = b_out_prob[:, b_tgt_ids]

            b_tgt_mask = targets[b_i]["match_masks"].unsqueeze(0)
            b_out_mask = out_mask[b_i]

            # Compute the dice coefficient cost between masks
            # The 1 is a constant that doesn't change the matching as cost_class, thus omitted.
            
            cost_dice = dice_coef(
                b_out_mask, b_tgt_mask
            ).to(cost_class)

            # Final cost matrix
            C = self.cost_dice * cost_dice + self.cost_class * cost_class

            indices.append(linear_sum_assignment(C.cpu(), maximize=True))

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class SetCriterion(nn.Module):
    """ This class computes the loss for IFC.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth masks and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and mask)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_frames):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.num_frames = num_frames
        empty_weight = torch.ones(num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_masks]
        """
        src_logits = outputs['pred_logits']

        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J]
                                     for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(
            1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - \
                accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_masks):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) !=
                     pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_masks, h, w]
        """
        assert "pred_masks" in outputs

        idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"][idx]
        target_masks = torch.cat(
            [t['masks'][i] for t, (_, i) in zip(targets, indices)]).to(src_masks)

        n, t = src_masks.shape[:2]
        t_h, t_w = target_masks.shape[-2:]

        src_masks = F.interpolate(src_masks, size=(
            t_h, t_w), mode="bilinear", align_corners=False)

        src_masks = src_masks.flatten(1)
        target_masks = target_masks.flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i)
                              for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i)
                              for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_masks, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k,
                               v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target masks accross all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device)
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / 1, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(
                loss, outputs_without_aux, targets, indices, num_masks))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_masks, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses



class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using a FPN approach
    """

    def __init__(self, dim, fpn_dims, context_dim, output_dict=None):
        super().__init__()

        inter_dims = [dim, context_dim // 2, context_dim // 4,
                      context_dim // 8, context_dim // 16, context_dim // 64, context_dim // 128]

        gn = 8

        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(gn, dim)
        self.lay2 = torch.nn.Conv2d(dim, inter_dims[1], 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(gn, inter_dims[1])
        self.lay3 = torch.nn.Conv2d(inter_dims[1], inter_dims[2], 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(gn, inter_dims[2])
        self.lay4 = torch.nn.Conv2d(inter_dims[2], inter_dims[3], 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(gn, inter_dims[3])
        self.lay5 = torch.nn.Conv2d(inter_dims[3], inter_dims[4], 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(gn, inter_dims[4])

        self.lay6 = torch.nn.Conv2d(inter_dims[4], inter_dims[4], 3, padding=1)
        self.gn6 = torch.nn.GroupNorm(gn, inter_dims[4])

        self.out_lay = torch.nn.Conv2d(
            inter_dims[4], 1, 3, padding=1)  # <- This would be differen

        # if output_dict is not None:
        #     self.future_pred_layers = build_output_convs(
        #         inter_dims[4], output_dict)
        """ 
        outheads_
            - motion_segmentation: 1x5x200x200   - BxFx1xHxW
        """

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], inter_dims[1], 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], inter_dims[2], 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], inter_dims[3], 1)
        self.adapter4 = torch.nn.Conv2d(fpn_dims[3], inter_dims[4], 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, bbox_mask, fpns):

        def expand(tensor, length):
            return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1)
        print(f"Input: {x.shape = } {bbox_mask.shape = }")
        x = torch.cat([expand(x, bbox_mask.shape[1]),
                      bbox_mask.flatten(0, 1)], 1)
        print(f"First Expand: {x.shape = }")
        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)
        #print(f"Before adapter1: {x.shape = }")
        cur_fpn = self.adapter1(fpns[0])
        #print(f"First cur_fpn: {cur_fpn.shape = }")
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
            #print(f"cur_fpn.size(0) != x.size(0): {cur_fpn.shape = }")
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        #print(f"Before adapter2: {x.shape = }")
        cur_fpn = self.adapter2(fpns[1])
        #print(f"2 adapter2: {cur_fpn.shape = }")
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
            #print(f"cur_fpn.size(0) != x.size(0): {cur_fpn.shape = }")

        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)
        #print(f"TBefore adapter3: {x.shape = }")

        cur_fpn = self.adapter3(fpns[2])
        #print(f"after adapter3: {cur_fpn.shape = }")
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
            #print(f"cur_fpn.size(0) != x.size(0): {cur_fpn.shape = }")
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        #print(f"Fourth Expand: {x.shape = }")
        cur_fpn = self.adapter4(fpns[3])
        #print(f"after adapter4: {cur_fpn.shape = }")
        if cur_fpn.size(0) != x.size(0):
            cur_fpn = expand(cur_fpn, x.size(0) / cur_fpn.size(0))
            #print(f"cur_fpn.size(0) != x.size(0): {cur_fpn.shape = }")
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")

        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay6(x)
        x = self.gn6(x)
        x = F.relu(x)
        print(f"Fourth Expand: {x.shape = }")
        #x = F.interpolate(x, size=200, mode="nearest")
        x = self.out_lay(x)
        print(f"out {x.shape}")
        return x


class MHAttentionMap(nn.Module):
    """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)"""

    def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_dim = hidden_dim
        self.dropout = nn.Dropout(dropout)

        self.q_linear = nn.Linear(query_dim, hidden_dim, bias=bias)
        self.k_linear = nn.Linear(query_dim, hidden_dim, bias=bias)

        nn.init.zeros_(self.k_linear.bias)
        nn.init.zeros_(self.q_linear.bias)
        nn.init.xavier_uniform_(self.k_linear.weight)
        nn.init.xavier_uniform_(self.q_linear.weight)
        self.normalize_fact = float(hidden_dim / self.num_heads) ** -0.5

    def forward(self, q, k, mask=None):
        q = self.q_linear(q)
        k = F.conv2d(
            k, self.k_linear.weight.unsqueeze(-1).unsqueeze(-1), self.k_linear.bias)
        qh = q.view(q.shape[0], q.shape[1], self.num_heads,
                    self.hidden_dim // self.num_heads)
        kh = k.view(k.shape[0], self.num_heads, self.hidden_dim //
                    self.num_heads, k.shape[-2], k.shape[-1])
        weights = torch.einsum("bqnc,bnchw->bqnhw",
                               qh * self.normalize_fact, kh)

        if mask is not None:
            weights.masked_fill_(mask.unsqueeze(1).unsqueeze(1), float("-inf"))
        weights = F.softmax(weights.flatten(2), dim=-1).view_as(weights)
        weights = self.dropout(weights)
        #print(f"MH AttentionMap Shape {weights.shape = }")
        return weights


class PostProcessSegm(nn.Module):
    def __init__(self, threshold=0.5):
        super().__init__()
        self.threshold = threshold

    @torch.no_grad()
    def forward(self, results, outputs, orig_target_sizes, max_target_sizes):
        assert len(orig_target_sizes) == len(max_target_sizes)
        max_h, max_w = max_target_sizes.max(0)[0].tolist()
        #print(f"{max_h = }, {max_w = }")
        out_logits, out_bbox = outputs['pred_logits'], outputs['pred_boxes']

        assert len(out_logits) == len(orig_target_sizes)
        assert orig_target_sizes.shape[1] == 2

        prob = out_logits.sigmoid()
        topk_values, topk_indexes = torch.topk(
            prob.view(out_logits.shape[0], -1), 100, dim=1)
        scores = topk_values
        topk_boxes = topk_indexes // out_logits.shape[2]
        labels = topk_indexes % out_logits.shape[2]
        boxes = box_cxcywh_to_xyxy(out_bbox)
        boxes = torch.gather(
            boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4))

        # and from relative [0, 1] to absolute [0, height] coordinates
        img_h, img_w = orig_target_sizes.unbind(1)

        #print(f"{img_h = }, {img_w = }")
        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1)
        #print(f" {scale_fct.shape = }")
        boxes = boxes * scale_fct[:, None, :]
        #print(f" {boxes.shape = }")
        out_mask = outputs["pred_masks"]
        #print(f" {out_mask.shape = }")
        B, R, H, W = out_mask.shape
        out_mask = out_mask.view(B, R, H * W)
        #print(f" {out_mask.shape = }")
        out_mask = torch.gather(
            out_mask, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, H * W))
        #print(f"After gather {out_mask.shape = }")
        outputs_masks = out_mask.view(B, 100, H, W).squeeze(2)

        outputs_masks = F.interpolate(outputs_masks, size=(
            max_h, max_w), mode="bilinear", align_corners=False)

        outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu()

        for i, (cur_mask, t, tt) in enumerate(zip(outputs_masks, max_target_sizes, orig_target_sizes)):
            img_h, img_w = t[0], t[1]
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            interpol_tmp = F.interpolate(
                results[i]["masks"].float(), size=tuple(tt.tolist()), mode="nearest"
            )

            results[i]["masks"] = interpol_tmp.byte()

        return results


# """ 
# 1. Get Code and shapes of in-/output  -- 
# 2. Get Matcher for the masks as well as postprocessing & loss function
# 3. Test based on real GTs 
# """

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


In [52]:

class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,padding=1,kernel_size=5, activation1= None,activation2=None):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin , kernel_size=kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin , nout, kernel_size=1)
        self.activation1 = activation1
        self.activation2 = activation2
    def forward(self, x):
        out = self.depthwise(x)
        if self.activation1 is not None:
            out = self.activation1(out)
        out = self.pointwise(out)
        if self.activation1 is not None:
            out = self.activation2(out)
        return out

class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using a FPN approach
    """

    def __init__(self, dim, fpn_dims, output_dict=None):
        super().__init__()

        # inter_dims = [dim, context_dim // 2, context_dim // 4,
        #               context_dim // 8, context_dim // 16, context_dim // 64, context_dim // 128]
        self.n_future = 4 
        gn = 8
         
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(gn, dim)
        self.lay2 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(gn, dim)
        self.lay3 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(gn, dim)
        self.lay4 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(gn, dim)
        self.lay5 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(gn, dim)
        
        self.lay6 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn6 = torch.nn.GroupNorm(gn, dim)

        self.depth_sep_conv2d =  depthwise_separable_conv(dim,dim,kernel_size=5,padding=2, activation1= F.relu,activation2= F.relu)

        # half_dim = dim/2     
        # self.out_lay_1 = torch.nn.Conv2d(
        #     dim, half_dim, 3, padding=1)
        # self.out_lay_2 = torch.nn.Conv2d(
        #     half_dim, 1, 3, padding=1)  # <- This would be differen
        
        self.convert_to_weight = MLP(dim, dim, dim, 3)
        # if output_dict is not None:
        #     self.future_pred_layers = build_output_convs(
        #         inter_dims[4], output_dict)
        """ 
        outheads_
            - motion_segmentation: 1x5x200x200   - BxFx1xHxW
        """

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], dim, 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], dim, 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], dim, 1)
        self.adapter4 = torch.nn.Conv2d(fpn_dims[3], dim, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, src, seg_memory, fpns, hs ):
        x = src + seg_memory
        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter4(fpns[3])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        T = self.n_future

        x = x.unsqueeze(1).repeat(1,T,1,1,1)
        B, BT, C, H, W = x.shape
        L, B, N, C = hs.shape
        x = self.depth_sep_conv2d(x.view(B*BT, C , H,W)).view(B,BT,C,H,W)

        w = self.convert_to_weight(hs).permute(1,0,2,3)
        w = w.unsqueeze(1).repeat(1,4,1,1,1)

        mask_logits = F.conv2d(x.view(1, BT*C, H, W), w.reshape(B*T*L*N, C, 1, 1), groups=BT)
        mask_logits = mask_logits.view(B, T, L, N, H, W).permute(2, 0, 3, 1, 4, 5)
        return mask_logits


In [53]:
n_future = 4
hidden_dim = 256
nheads = 8
num_queries = 100

gt_instance = torch.randint(low=0,high=2, size=(1, n_future, 200, 200)).to(torch.float32)
seg_memory = torch.rand((1, hidden_dim, 13, 13))
seg_mask = torch.randint(low=0, high=1, size=(1, 13, 13))
hs = torch.rand([3, 1, num_queries, hidden_dim]) # N x B X Q x H <. N layers , B batchsize, query dim , hidden 
init_reference = torch.rand([1, num_queries, 2])
srcs = torch.rand([1, hidden_dim, 13, 13])

features = [
    torch.rand((1, 64, 100, 100)),
    torch.rand((1, 128, 50, 50)),
    torch.rand((1, 256, 25, 25)),
    torch.rand((1, 512, 13, 13)),
]

input_projections = [(features[-1]),
                     (features[-2]), (features[-3]), features[-4]]



bbox_attention = (MHAttentionMap(
    hidden_dim, hidden_dim, nheads, dropout=0))

fpn_dims = [512, 256, 128, 64]



class_mlp = MLP(hidden_dim, hidden_dim, output_dim=num_queries, num_layers=2)




In [54]:
features = [ # with input projection
    torch.rand((1, 256, 100, 100)),
    torch.rand((1, 256, 50, 50)),
    torch.rand((1, 256, 25, 25)),
    torch.rand((1, 256, 13, 13)),
]
input_projections = [(features[-1]),
                     (features[-2]), (features[-3]), features[-4]]
#
fpn_dims = [256, 256, 256, 256]


def _set_aux_loss( outputs_class, outputs_masks):
    # this is a workaround to make torchscript happy, as torchscript
    # doesn't support dictionary with non-homogeneous values, such
    # as a dict having both a Tensor and a list.
    return [{'pred_logits': a, 'pred_masks': b}
            for a, b in zip(outputs_class[:-1], outputs_masks[:-1])]
aux_loss = True 
class_logits_list = []
for i in range(n_future):
    class_logits_list.append( class_mlp(hs[-1]))

outputs_class = torch.stack(class_logits_list)
print(outputs_class.shape)
mask_head = MaskHeadSmallConv(hidden_dim,fpn_dims)

outputs_masks = mask_head(
        srcs, seg_memory, input_projections,hs)


out = {'pred_logits': outputs_class[-1]}
out.update({'pred_masks': outputs_masks[-1]})

if aux_loss:
    out['aux_outputs'] = _set_aux_loss(outputs_class, outputs_masks)


torch.Size([4, 1, 100, 100])


In [49]:
MaxID = 10 
B = 2
Target =torch.randint(low=0, high=high,size=(B,3,5,5)).to(torch.float32) 
temp = torch.arange(MaxID).unsqueeze(0).repeat(B, 1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
Bool_masks  = (temp== Target.unsqueeze(1)).float() 

torch.Size([1, 10, 1, 1, 1])


RuntimeError: The size of tensor a (10) must match the size of tensor b (3) at non-singleton dimension 1

Issue GN does not work with 192 Hidden Dim -> maybe just lin reprojectoin layer 
IFC skips BBox Attn and predicts class -> for me kinda irrelevant since only one class

In [72]:
num_classes = 50 
num_frames = 4
dice_weight=3.0
mask_weight=3.0
no_object_weight = 0.1 
deep_supervision = True
dec_layers = 3


matcher = HungarianMatcherIFC(
    cost_class=1,
    cost_dice=dice_weight,
    num_classes=num_classes,
    )
weight_dict = {"loss_ce": 1, "loss_mask": mask_weight,
                "loss_dice": dice_weight}
if deep_supervision:
    aux_weight_dict = {}
    for i in range(dec_layers - 1):
        aux_weight_dict.update(
            {k + f"_{i}": v for k, v in weight_dict.items()})
    weight_dict.update(aux_weight_dict)
losses = ["labels", "masks", "cardinality"]
criterion = SetCriterion(
    num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses,
    num_frames=num_frames
)


In [61]:
import math
import torch.nn.functional as F 

from projects.mmdet3d_plugin.datasets.utils.warper import FeatureWarper
import os
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)

import numpy as np 
import matplotlib.pyplot as plt 
from mmcv import Config


def import_modules_load_config(cfg_file="beverse_tiny.py", samples_per_gpu=1):
    cfg_path = r"/home/niklas/ETM_BEV/BEVerse/projects/configs"
    cfg_path = os.path.join(cfg_path, cfg_file)

    cfg = Config.fromfile(cfg_path)

    # if args.cfg_options is not None:
    #     cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # import modules from plguin/xx, registry will be updated
    if hasattr(cfg, "plugin"):
        if cfg.plugin:
            import importlib

            if hasattr(cfg, "plugin_dir"):
                plugin_dir = cfg.plugin_dir
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]

                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            else:
                # import dir is the dirpath for the config file
                _module_dir = cfg_path
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)

    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test]
        )
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    return cfg


torch.backends.cudnn.benchmark = True
cfg = import_modules_load_config(
    cfg_file=r"beverse_tiny_org.py")

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=False,
    shuffle=False)


grid_conf = {
    "xbound": [-50.0, 50.0, 0.5],
    "ybound": [-50.0, 50.0, 0.5],
    "zbound": [-10.0, 10.0, 20.0],
    "dbound": [1.0, 60.0, 1.0],
}

warper = FeatureWarper(grid_conf=grid_conf)


class pseud_class:
    def __init__(self) -> None:
        
        self.receptive_field = 4
        self.warper = FeatureWarper(grid_conf=grid_conf)
        
    def prepare_targets(self, batch,bev_size = (200,200), mask_stride=2,match_stride=2):
        segmentation_labels = batch["motion_segmentation"][0]
        gt_instance = batch["motion_instance"][0]
        future_egomotion = batch["future_egomotions"][0]
        batch_size = len(segmentation_labels)
        labels = {}

        bev_transform = batch.get("aug_transform", None)
        labels["img_is_valid"] = batch.get("img_is_valid", None)

        if bev_transform is not None:
            bev_transform = bev_transform.float()
        #warping so all segmentation labels are inside the current BEV frame FIERY reports better convergence / performance if you do this 
        # segmentation_labels = (
        #     self.warper.cumulative_warp_features_reverse(
        #         segmentation_labels.float().unsqueeze(2),
        #         future_egomotion[:, (self.receptive_field - 1) :],
        #         mode="nearest",
        #         bev_transform=bev_transform,
        #     )
        #     .long()
        #     .contiguous()
        # ).squeeze().to(torch.float32)
        #print(f"Seg labels shape: {segmentation_labels.shape =}")

        
        # Warp instance labels to present's reference frame
        gt_instance = (
            self.warper.cumulative_warp_features_reverse(
                gt_instance.float().unsqueeze(2),
                future_egomotion[:, (self.receptive_field - 1) :],
                mode="nearest",
                bev_transform=bev_transform,
            )
            .long()
            .contiguous()[:, :, 0]
        )
        # better solution by abdur but unsure how to make it work with the rest of the code specifcally maxID since it can be diffferent for batches
        # temp = torch.arange(MaxID).unsqueeze(0).repeat(B, 1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        # gt_masks_ifc_dim  = (temp== Target.unsqueeze(1)).float() 
        target_list = []
        for b in range(batch_size):
            gt_list = []
            ids = len(gt_instance[b].unique())
            for _id in range(ids):
                test_bool = torch.where(gt_instance[b] == _id,1.,0.)
                gt_list.append(test_bool)

            segmentation_labels = torch.stack(gt_list,dim=0)
            
            #segmentation_labels = torch.stack(gt_batch_instances_list,dim=0)
            o_h, o_w = bev_size
            l_h, l_w = math.ceil(o_h/mask_stride), math.ceil(o_w/mask_stride)
            m_h, m_w = math.ceil(o_h/match_stride), math.ceil(o_w/match_stride)

            gt_masks_for_loss  = F.interpolate(segmentation_labels, size=(l_h, l_w), mode="bilinear", align_corners=False)
            gt_masks_for_match = F.interpolate(segmentation_labels, size=(m_h, m_w), mode="bilinear", align_corners=False)

            ids = gt_instance[b].unique() # labels only continous for clip - this is much more of an tracking id as every class is a vehicle anyways # TODO make work with other types of superclasses other then vehicle
            target_list.append({"labels": ids,"masks": gt_masks_for_loss[b], "match_masks": gt_masks_for_match[b], "gt_motion_instance":gt_instance[b] })
        
        return target_list 

projects.mmdet3d_plugin


In [57]:
sample = next(iter(data_loader))

In [62]:
p=pseud_class()
target_list = p.prepare_targets(sample)


In [64]:
target_list[0].keys()

dict_keys(['labels', 'masks', 'match_masks', 'gt_motion_instance'])

In [73]:
loss_dict = criterion(out, target_list)



RuntimeError: weight tensor should be defined either for all or no classes

In [36]:
pred_masks_stacked = torch.stack(out["pred_masks"]).transpose_(1,0)
# list of BxQxHxW
pred_masks_stacked.shape

torch.Size([1, 4, 100, 100, 100])

In [32]:
gt_instances[0]["match_masks"].shape


torch.Size([1, 4, 100, 100])

In [43]:
out_prob = torch.stack(out["pred_logits"]).transpose(1, 0).softmax(-1)


In [38]:
out_prob = out["pred_logits"].softmax(-1)


AttributeError: 'list' object has no attribute 'softmax'

In [58]:
out_mask = torch.stack(out["pred_masks"]).transpose(1, 0)
out_mask.shape

torch.Size([1, 4, 100, 100, 100])

In [59]:
out_prob = torch.stack(out["pred_logits"]).transpose(
    1, 0).softmax(-1)  # BxNxCxQ
out_mask = torch.stack(out["pred_masks"]).transpose(1, 0)
B, T, Q, s_h, s_w = out_mask.shape
t_h, t_w = gt_instances[0]["match_masks"].shape[-2:]

if (s_h, s_w) != (t_h, t_w):
    out_mask = out_mask.reshape(B, Q*T, s_h, s_w)
    out_mask = torch.nn.F.interpolate(out_mask, size=(
        t_h, t_w), mode="bilinear", align_corners=False)
    out_mask = out_mask.view(B, Q, T, t_h, t_w)


In [61]:
match_list = []

#out_prob = out["pred_logits"].softmax(-1)
out_prob = torch.stack(out["pred_logits"]).transpose(1, 0).softmax(-1) # BxNxCxQ 
out_mask = torch.stack(out["pred_masks"]).transpose(1, 0)
B, T, Q, s_h, s_w = out_mask.shape
t_h, t_w = gt_instances[0]["match_masks"].shape[-2:]

if (s_h, s_w) != (t_h, t_w):
    out_mask = out_mask.reshape(B, Q*T, s_h, s_w)
    out_mask = torch.nn.F.interpolate(out_mask, size=(
        t_h, t_w), mode="bilinear", align_corners=False)
    out_mask = out_mask.view(B, Q, T, t_h, t_w)

indices = []
for b_i in range(B):
    b_tgt_ids = gt_instances[b_i]["labels"]
    b_out_prob = out_prob[b_i]

    cost_class = b_out_prob[:, b_tgt_ids]

    b_tgt_mask = gt_instances[b_i]["match_masks"]
    b_out_mask = out_mask[b_i]

    # Compute the dice coefficient cost between masks
    # The 1 is a constant that doesn't change the matching as cost_class, thus omitted.
    cost_dice = dice_coef(
        b_out_mask, b_tgt_mask
    ).to(cost_class)

    # Final cost matrix
    C = cost_dice * cost_dice + cost_class * cost_class

    indices.append(linear_sum_assignment(C.cpu(), maximize=True))

matches = [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(
    j, dtype=torch.int64)) for i, j in indices]
match_list.append(matches)


RuntimeError: The size of tensor a (1000000) must match the size of tensor b (40000) at non-singleton dimension 2

In [64]:
gt_instances[0]["masks"].shape


torch.Size([1, 4, 100, 100])

In [65]:
out_prob = torch.stack(out["pred_logits"]).transpose(
    1, 0).softmax(-1)  # BxNxCxQ
out_mask = torch.stack(out["pred_masks"]).transpose(1, 0)
B, T, Q, s_h, s_w = out_mask.shape
t_h, t_w = gt_instances[0]["match_masks"].shape[-2:]


targets = gt_instances[0]["match_masks"][0]
inputs = out_mask[0]

print(targets.shape, inputs.shape)

inputs = inputs.sigmoid()
inputs = inputs.flatten(1).unsqueeze(1)
targets = targets.flatten(1).unsqueeze(0)
numerator = 2 * (inputs * targets).sum(2)
denominator = inputs.sum(-1) + targets.sum(-1)

# NOTE coef doesn't be subtracted to 1 as it is not necessary for computing costs
coef = (numerator + 1) / (denominator + 1)


torch.Size([4, 100, 100]) torch.Size([4, 100, 100, 100])


RuntimeError: The size of tensor a (1000000) must match the size of tensor b (10000) at non-singleton dimension 2