In [None]:
import torch 
import torch.nn as nn 

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class ConvExtractor(nn.Module):
    def __init__(self, inplanes, planes=[128, 256, 512, 1024], stride=1, groups=1) -> None:
        super().__init__()
        self.conv1 = conv3x3(64, planes[0], kernel_size=3, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.conv2 = conv3x3(planes[0], planes[1], 2)
        self.conv3 = conv3x3(planes[1], planes[2], 2)
        #self.conv4 = conv3x3(inplanes, planes, 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return_dict = {}
        out = self.maxpool(self.relu(self.conv1(x)))
        return_dict["layer1"]
    

torch.rand((1,64,200,200))




In [39]:
import torch 
import torch.nn as nn 

import torch.nn.functional as F 
from scipy.optimize import linear_sum_assignment
from torchvision.ops.boxes import box_area
import numpy as np 


# modified from torchvision to also return the union






def dice_coef(inputs, targets):
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1).unsqueeze(1)
    targets = targets.flatten(1).unsqueeze(0)
    numerator = 2 * (inputs * targets).sum(2)
    denominator = inputs.sum(-1) + targets.sum(-1)

    # NOTE coef doesn't be subtracted to 1 as it is not necessary for computing costs
    coef = (numerator + 1) / (denominator + 1)
    return coef


def dice_loss(inputs, targets, num_boxes):
    """
    Compute the DICE loss, similar to generalized IOU for masks
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
    """
    inputs = inputs.sigmoid()
    inputs = inputs.flatten(1)
    numerator = 2 * (inputs * targets).sum(1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_boxes


def sigmoid_focal_coef(inputs, targets, alpha: float = 0.25, gamma: float = 2):
    N, M = len(inputs), len(targets)
    inputs = inputs.flatten(1).unsqueeze(1).expand(-1, M, -1)
    targets = targets.flatten(1).unsqueeze(0).expand(N, -1, -1)

    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    coef = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        coef = alpha_t * coef

    return coef.mean(2)


def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
    Args:
        inputs: A float tensor of arbitrary shape.
                The predictions for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
                positive vs negative examples. Default = -1 (no weighting).
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples.
    Returns:
        Loss tensor
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(
        inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.mean(1).sum() / num_boxes



class HungarianMatcherIFC(nn.Module):
    """This class computes an assignment between the targets and the predictions of the network

    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
    while the others are un-matched (and thus treated as non-objects).
    """

    def __init__(
        self,
        cost_class: float = 1,
        cost_dice: float = 1,
        num_classes: int = 80,
    ):
        """Creates the matcher

        Params:
            cost_class: This is the relative weight of the classification error in the matching cost
            cost_mask: This is the relative weight of the sigmoid_focal error of the masks in the matching cost
            cost_dice: This is the relative weight of the dice loss of the masks in the matching cost
        """
        super().__init__()
        self.cost_class = cost_class
        self.cost_dice = cost_dice
        assert cost_class != 0 or cost_dice != 0, "all costs cant be 0"

        self.num_classes = num_classes
        self.num_cum_classes = [0] + \
            np.cumsum(np.array(num_classes) + 1).tolist()
        self.n_future = 4
    @torch.no_grad()
    def forward(self, outputs, targets):
        # We flatten to compute the cost matrices in a batch
        out_prob = outputs["pred_logits"].softmax(-1)
        out_mask = outputs["pred_masks"]
        B, Q, T, s_h, s_w = out_mask.shape
        t_h, t_w = targets[0]["match_masks"].shape[-2:]

        if (s_h, s_w) != (t_h, t_w):
            out_mask = out_mask.reshape(B, Q*T, s_h, s_w)
            out_mask = torch.nn.F.interpolate(out_mask, size=(
                t_h, t_w), mode="bilinear", align_corners=False)
            out_mask = out_mask.view(B, Q, T, t_h, t_w)

        indices = []
        for b_i in range(B):
            b_tgt_ids = targets[b_i]["labels"]
            b_out_prob = out_prob[b_i]

            cost_class = b_out_prob[:, b_tgt_ids]

            b_tgt_mask = targets[b_i]["match_masks"]
            b_out_mask = out_mask[b_i]

            # Compute the dice coefficient cost between masks
            # The 1 is a constant that doesn't change the matching as cost_class, thus omitted.
            
            cost_dice = dice_coef(
                b_out_mask, b_tgt_mask
            ).to(cost_class)

            print(f"{cost_dice.shape = } {cost_class.shape = } {b_tgt_ids.shape = } {b_out_prob.shape = } {b_tgt_mask.shape = } {b_out_mask.shape = }")
            # Final cost matrix
            C = self.cost_dice * cost_dice + self.cost_class * cost_class

            indices.append(linear_sum_assignment(C.cpu(), maximize=True))

        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]


@torch.no_grad()
def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    if target.numel() == 0:
        return [torch.zeros([], device=output.device)]
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

class SetCriterion(nn.Module):
    """ This class computes the loss for IFC.
    The process happens in two steps:
        1) we compute hungarian assignment between ground truth masks and the outputs of the model
        2) we supervise each pair of matched ground-truth / prediction (supervise class and mask)
    """

    def __init__(self, num_classes, matcher, weight_dict, eos_coef, losses, num_frames):
        """ Create the criterion.
        Parameters:
            num_classes: number of object categories, omitting the special no-object category
            matcher: module able to compute a matching between targets and proposals
            weight_dict: dict containing as key the names of the losses and as values their relative weight.
            eos_coef: relative classification weight applied to the no-object category
            losses: list of all the losses to be applied. See get_loss for list of available losses.
        """
        super().__init__()
        self.num_classes = num_classes
        self.matcher = matcher
        self.weight_dict = weight_dict
        self.eos_coef = eos_coef
        self.losses = losses
        self.num_frames = num_frames
        empty_weight = torch.ones(num_classes + 1)
        empty_weight[-1] = self.eos_coef
        self.register_buffer('empty_weight', empty_weight)

    def loss_labels(self, outputs, targets, indices, num_masks, log=True):
        """Classification loss (NLL)
        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_masks]
        """
        src_logits = outputs['pred_logits']
        idx = self._get_src_permutation_idx(indices)
        target_classes_o = torch.cat([t["labels"][J]
                                     for t, (_, J) in zip(targets, indices)])
        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                    dtype=torch.int64, device=src_logits.device)
        target_classes[idx] = target_classes_o

        loss_ce = F.cross_entropy(src_logits.transpose(
            1, 2), target_classes, self.empty_weight)
        losses = {'loss_ce': loss_ce}

        if log:
            # TODO this should probably be a separate loss, not hacked in this one here
            losses['class_error'] = 100 - \
                accuracy(src_logits[idx], target_classes_o)[0]
        return losses

    @torch.no_grad()
    def loss_cardinality(self, outputs, targets, indices, num_masks):
        """ Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes
        This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients
        """
        pred_logits = outputs['pred_logits']
        device = pred_logits.device
        tgt_lengths = torch.as_tensor(
            [len(v["labels"]) for v in targets], device=device)
        # Count the number of predictions that are NOT "no-object" (which is the last class)
        card_pred = (pred_logits.argmax(-1) !=
                     pred_logits.shape[-1] - 1).sum(1)
        card_err = F.l1_loss(card_pred.float(), tgt_lengths.float())
        losses = {'cardinality_error': card_err}
        return losses

    def loss_masks(self, outputs, targets, indices, num_masks):
        """Compute the losses related to the masks: the focal loss and the dice loss.
           targets dicts must contain the key "masks" containing a tensor of dim [nb_target_masks, h, w]
        """
        assert "pred_masks" in outputs

        idx = self._get_src_permutation_idx(indices)
        src_masks = outputs["pred_masks"][idx]

        target_masks = torch.cat(
            [t['masks'][i] for t, (_, i) in zip(targets, indices)]).to(src_masks)

        n, t = src_masks.shape[:2]
        t_h, t_w = target_masks.shape[-2:]

        src_masks = F.interpolate(src_masks, size=(
            t_h, t_w), mode="bilinear", align_corners=False)

        src_masks = src_masks.flatten(1)
        target_masks = target_masks.flatten(1)

        losses = {
            "loss_mask": sigmoid_focal_loss(src_masks, target_masks, num_masks),
            "loss_dice": dice_loss(src_masks, target_masks, num_masks),
        }
        return losses

    def _get_src_permutation_idx(self, indices):
        # permute predictions following indices
        batch_idx = torch.cat([torch.full_like(src, i)
                              for i, (src, _) in enumerate(indices)])
        src_idx = torch.cat([src for (src, _) in indices])
        return batch_idx, src_idx

    def _get_tgt_permutation_idx(self, indices):
        # permute targets following indices
        batch_idx = torch.cat([torch.full_like(tgt, i)
                              for i, (_, tgt) in enumerate(indices)])
        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
        return batch_idx, tgt_idx

    def get_loss(self, loss, outputs, targets, indices, num_masks, **kwargs):
        loss_map = {
            'labels': self.loss_labels,
            'cardinality': self.loss_cardinality,
            'masks': self.loss_masks
        }
        assert loss in loss_map, f'do you really want to compute {loss} loss?'
        return loss_map[loss](outputs, targets, indices, num_masks, **kwargs)

    def forward(self, outputs, targets):
        """ This performs the loss computation.
        Parameters:
             outputs: dict of tensors, see the output specification of the model for the format
             targets: list of dicts, such that len(targets) == batch_size.
                      The expected keys in each dict depends on the losses applied, see each loss' doc
        """
        outputs_without_aux = {k: v for k,
                               v in outputs.items() if k != 'aux_outputs'}

        # Retrieve the matching between the outputs of the last layer and the targets
        indices = self.matcher(outputs_without_aux, targets)

        # Compute the average number of target masks accross all nodes, for normalization purposes
        num_masks = sum(len(t["labels"]) for t in targets)
        num_masks = torch.as_tensor(
            [num_masks], dtype=torch.float, device=next(iter(outputs.values())).device)
        # if is_dist_avail_and_initialized():
        #     torch.distributed.all_reduce(num_masks)
        num_masks = torch.clamp(num_masks / 1, min=1).item()

        # Compute all the requested losses
        losses = {}
        for loss in self.losses:
            losses.update(self.get_loss(
                loss, outputs_without_aux, targets, indices, num_masks))

        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
        if 'aux_outputs' in outputs:
            for i, aux_outputs in enumerate(outputs['aux_outputs']):
                indices = self.matcher(aux_outputs, targets)
                for loss in self.losses:
                    kwargs = {}
                    if loss == 'labels':
                        # Logging is enabled only for the last layer
                        kwargs = {'log': False}
                    l_dict = self.get_loss(
                        loss, aux_outputs, targets, indices, num_masks, **kwargs)
                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
                    losses.update(l_dict)

        return losses



# """ 
# 1. Get Code and shapes of in-/output  -- 
# 2. Get Matcher for the masks as well as postprocessing & loss function
# 3. Test based on real GTs 
# """




In [46]:
class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout,padding=1,kernel_size=5, activation1= None,activation2=None):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(nin, nin , kernel_size=kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin , nout, kernel_size=1)
        self.activation1 = activation1
        self.activation2 = activation2
    def forward(self, x):
        out = self.depthwise(x)
        if self.activation1 is not None:
            out = self.activation1(out)
        out = self.pointwise(out)
        if self.activation1 is not None:
            out = self.activation2(out)
        return out

class MaskHeadSmallConv(nn.Module):
    """
    Simple convolutional head, using group norm.
    Upsampling is done using a FPN approach
    """

    def __init__(self, dim, fpn_dims, output_dict=None):
        super().__init__()

        # inter_dims = [dim, context_dim // 2, context_dim // 4,
        #               context_dim // 8, context_dim // 16, context_dim // 64, context_dim // 128]
        self.n_future = 5 
        gn = 8
         
        self.lay1 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn1 = torch.nn.GroupNorm(gn, dim)
        self.lay2 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn2 = torch.nn.GroupNorm(gn, dim)
        self.lay3 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn3 = torch.nn.GroupNorm(gn, dim)
        self.lay4 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn4 = torch.nn.GroupNorm(gn, dim)
        self.lay5 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        self.gn5 = torch.nn.GroupNorm(gn, dim)
        
        # self.lay6 = torch.nn.Conv2d(dim, dim, 3, padding=1)
        # self.gn6 = torch.nn.GroupNorm(gn, dim)

        self.depth_sep_conv2d =  depthwise_separable_conv(dim,dim,kernel_size=5,padding=2, activation1= F.relu,activation2= F.relu)

        # half_dim = dim/2     
        # self.out_lay_1 = torch.nn.Conv2d(
        #     dim, half_dim, 3, padding=1)
        # self.out_lay_2 = torch.nn.Conv2d(
        #     half_dim, 1, 3, padding=1)  # <- This would be differen
        
        self.convert_to_weight = MLP(dim, dim, dim, 3)
        # if output_dict is not None:
        #     self.future_pred_layers = build_output_convs(
        #         inter_dims[4], output_dict)
        """ 
        outheads_
            - motion_segmentation: 1x5x200x200   - BxFx1xHxW
        """

        self.dim = dim

        self.adapter1 = torch.nn.Conv2d(fpn_dims[0], dim, 1)
        self.adapter2 = torch.nn.Conv2d(fpn_dims[1], dim, 1)
        self.adapter3 = torch.nn.Conv2d(fpn_dims[2], dim, 1)
        self.adapter4 = torch.nn.Conv2d(fpn_dims[3], dim, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight, a=1)
                nn.init.constant_(m.bias, 0)

    def forward(self, src, seg_memory, fpns, hs ):
        x = src + seg_memory
        x = self.lay1(x)
        x = self.gn1(x)
        x = F.relu(x)

        cur_fpn = self.adapter1(fpns[0])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay2(x)
        x = self.gn2(x)
        x = F.relu(x)

        cur_fpn = self.adapter2(fpns[1])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay3(x)
        x = self.gn3(x)
        x = F.relu(x)

        cur_fpn = self.adapter3(fpns[2])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        #print(f"Interpolutaion with expan: {x.shape = }")
        x = self.lay4(x)
        x = self.gn4(x)
        x = F.relu(x)

        cur_fpn = self.adapter4(fpns[3])
        x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
        x = self.lay5(x)
        x = self.gn5(x)
        x = F.relu(x)

        T = self.n_future

        x = x.unsqueeze(1).repeat(1,T,1,1,1)
        B, BT, C, H, W = x.shape
        L, B, N, C = hs.shape

        x = self.depth_sep_conv2d(x.view(B*BT, C , H,W)).view(B,BT,C,H,W)


        w = self.convert_to_weight(hs).permute(1,0,2,3)
        print(w.shape)
        w = w.unsqueeze(1).repeat(1,T,1,1,1)
        print(w.shape)
        
        
        mask_logits = F.conv2d(x.view(1, BT*C, H, W), w.reshape(B*T*L*N, C, 1, 1), groups=BT)
        mask_logits = mask_logits.view(B, T, L, N, H, W).permute(2, 0, 3, 1, 4, 5)
        return mask_logits


In [27]:
import torch 
import torch.nn as nn 
from torch.nn import functional as F 

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k)
                                    for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x


class depthwise_separable_conv(nn.Module):
    def __init__(self, nin, nout, padding=1, kernel_size=5, activation1=None, activation2=None):
        super(depthwise_separable_conv, self).__init__()
        self.depthwise = nn.Conv2d(
            nin, nin, kernel_size=kernel_size, padding=padding, groups=nin)
        self.pointwise = nn.Conv2d(nin, nout, kernel_size=1)
        self.activation1 = activation1
        self.activation2 = activation2

    def forward(self, x):
        out = self.depthwise(x)
        if self.activation1 is not None:
            out = self.activation1(out)
        out = self.pointwise(out)
        if self.activation1 is not None:
            out = self.activation2(out)
        return out
    
dim = 256
gn=8
num_queries = 300
hidden_dim = dim 
T = 5 
hs = torch.rand([6, 2, num_queries, hidden_dim])
fpn3 = torch.rand((2, 256, 50, 50))
fpn4 = torch.rand((2, 512, 100, 100))


lay4 = torch.nn.Conv2d(dim, dim*T, 3, padding=1)
gn4 = torch.nn.GroupNorm(gn, dim*T)
lay5 = torch.nn.Conv2d(dim*2, dim*T, 3, padding=1)
gn5 = torch.nn.GroupNorm(gn, dim*T)
adapter3 = torch.nn.Conv2d(256, dim, 1)
adapter4 = torch.nn.Conv2d(dim*2, dim*T, 1)
convert_to_weight = MLP(dim, dim, dim*T, 2)
depth_sep_conv2d = depthwise_separable_conv(
    dim, dim, kernel_size=5, padding=2, activation1=F.relu, activation2=F.relu)


a = nn.Sequential(
    nn.Conv3d(dim, dim, kernel_size=[1, 1, 1], bias=False),
    nn.BatchNorm3d(
        num_features=dim, eps=1e-5, momentum=0.1
    ),
    nn.ReLU(inplace=True),
)

# Depthwise (channel-separated) 3x3x3x1 conv
# Depthwise (channel-separated) 1x3x3x1 spatial conv
b1 = nn.Conv3d(
    dim,
    dim,
    kernel_size=[1, 3, 3],
    stride=[1, 1, 1],
    padding=[0, 1, 1],
    bias=False,
)
# Depthwise (channel-separated) 3x1x1x1 temporal conv
b2 = nn.Conv3d(
    dim,
    dim,
    kernel_size=[3, 1, 1],
    stride=[1, 1, 1],
    padding=[1, 0, 0],
    bias=False,
)


x = torch.rand((1, 256, 25, 25))
print(f"input {fpn3.shape = }, {x.shape}")
cur_fpn = adapter3(fpn3)
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
#print(f"Interpolutaion with expan: {x.shape = }")
print(f"interpolation, {x.shape}")
x = lay4(x)
x = gn4(x)
x = F.relu(x)

print(f"after adapter1 {x.shape }")
cur_fpn = adapter4(fpn4)
print(f" adapter2 {cur_fpn.shape } {x.shape}")
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")


T = 5
H, W = x.shape[-2:]
B = 2
# x.unsqueeze(1).reshape(1, T, -1, H, W)
x = x.unsqueeze(1).reshape(B, -1, T, H, W)
x = b1(x)
x = b2(x)
x = F.relu(x)
x = a(x).permute(0, 2, 1, 3, 4)

B, BT, C, H, W = x.shape
L, B, N, C = hs.shape
print(f"{B}, {BT} {T}")
# x = depth_sep_conv2d(x.view(B*BT, C, H, W)).view(B, BT, C, H, W)

print(f"after reshape {x.shape }")
print(f"HS input {hs.shape }")
w = convert_to_weight(hs).permute(1, 0, 2, 3)
print(f"after weight {w.shape }")
#torch.Size([1, 6, 100, 256])
#torch.Size([1, 5, 6, 100, 256])
w = w.unsqueeze(1).reshape(B, T, L, N, -1)
print(f"after reshape {w.shape }")
print(x.shape)
x = x.reshape(1, B*BT*C, H, W)
mask_logits = F.conv2d(x,
                       w.reshape(B*T*L*N, C, 1, 1), groups=T*B)
print(f"mask logits {mask_logits.shape }")
mask_logits = mask_logits.view(
    B, T, L, N, H, W).permute(2, 0, 3, 1, 4, 5)
print(f"mask logits {mask_logits.shape }")


input fpn3.shape = torch.Size([2, 256, 50, 50]), torch.Size([1, 256, 25, 25])
interpolation, torch.Size([2, 256, 50, 50])
after adapter1 torch.Size([2, 1280, 50, 50])
 adapter2 torch.Size([2, 1280, 100, 100]) torch.Size([2, 1280, 50, 50])
2, 5 5
after reshape torch.Size([2, 5, 256, 100, 100])
HS input torch.Size([6, 2, 300, 256])
after weight torch.Size([2, 6, 300, 1280])
after reshape torch.Size([2, 5, 6, 300, 256])
torch.Size([2, 5, 256, 100, 100])
mask logits torch.Size([1, 18000, 100, 100])
mask logits torch.Size([6, 2, 300, 5, 100, 100])


In [16]:
x = torch.rand((2, 256, 25, 25))
print(f"input {fpn3.shape = }, {x.shape}")
cur_fpn = adapter3(fpn3)
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")
#print(f"Interpolutaion with expan: {x.shape = }")
print(f"interpolation, {x.shape}")
x = lay4(x)
x = gn4(x)
x = F.relu(x)

print(f"after adapter1 {x.shape }")
cur_fpn = adapter4(fpn4)
print(f" adapter2 {cur_fpn.shape } {x.shape}")
x = cur_fpn + F.interpolate(x, size=cur_fpn.shape[-2:], mode="nearest")


T = 5
H, W = x.shape[-2:]
# x.unsqueeze(1).reshape(1, T, -1, H, W)
x = x.unsqueeze(1).reshape(1, -1, T, H, W)
x= b1(x)
x=b2(x)
x = F.relu(x)
x = a(x).permute(0, 2, 1, 3, 4)

B, BT, C, H, W = x.shape
L, B, N, C = hs.shape

# x = depth_sep_conv2d(x.view(B*BT, C, H, W)).view(B, BT, C, H, W)

print(f"after reshape {x.shape }")
print(f"HS input {hs.shape }")
w = convert_to_weight(hs).permute(1, 0, 2, 3)
print(f"after weight {w.shape }")
#torch.Size([1, 6, 100, 256])
#torch.Size([1, 5, 6, 100, 256])
w = w.unsqueeze(1).reshape(1, T, L, N, -1)
print(f"after reshape {w.shape }")
print(x.shape)
x = x.reshape(1, BT*C, H, W)
mask_logits = F.conv2d(x,
                       w.reshape(B*T*L*N, C, 1, 1), groups=BT)
print(f"mask logits {w.shape }")
mask_logits = mask_logits.view(
    B, T, L, N, H, W).permute(2, 0, 3, 1, 4, 5)


input fpn3.shape = torch.Size([1, 256, 50, 50]), torch.Size([1, 256, 25, 25])
interpolation, torch.Size([1, 256, 50, 50])
after adapter1 torch.Size([1, 1280, 50, 50])
 adapter2 torch.Size([1, 1280, 100, 100]) torch.Size([1, 1280, 50, 50])
after reshape torch.Size([1, 5, 256, 100, 100])
HS input torch.Size([6, 1, 500, 256])
after weight torch.Size([1, 6, 500, 1280])
after reshape torch.Size([1, 5, 6, 500, 256])
torch.Size([1, 5, 256, 100, 100])
mask logits torch.Size([1, 5, 6, 500, 256])


In [27]:
test_reshape = torch.rand((1, 1280, 100, 100))
test_reshape= test_reshape.unsqueeze(1)
test_reshape.reshape(1,T,-1,100,100).shape

torch.Size([1, 5, 256, 100, 100])

In [30]:
test_reshape = torch.rand((1, 1280, 100, 100))
test_reshape.shape[-2:]


torch.Size([100, 100])

In [47]:
n_future = 5
hidden_dim = 256
nheads = 8
num_queries = 100
num_classes = 50 

gt_instance = torch.randint(low=0,high=2, size=(1, n_future, 200, 200)).to(torch.float32)
seg_memory = torch.rand((1, hidden_dim, 13, 13))
seg_mask = torch.randint(low=0, high=1, size=(1, 13, 13))
hs = torch.rand([6, 1, num_queries, hidden_dim]) # N x B X Q x H <. N layers , B batchsize, query dim , hidden 
init_reference = torch.rand([1, num_queries, 2])
srcs = torch.rand([1, hidden_dim, 13, 13])

class_mlp = MLP(hidden_dim, hidden_dim, output_dim=num_classes + 1, num_layers=2)

features = [ # with input projection
    torch.rand((1, 256, 100, 100)), #64
    torch.rand((1, 256, 50, 50)), #128
    torch.rand((1, 256, 25, 25)), # 256
    torch.rand((1, 256, 13, 13)), # 512 
]
input_projections = [(features[-1]),
                     (features[-2]), (features[-3]), features[-4]]
#
fpn_dims = [256, 256, 256, 256]


def _set_aux_loss( outputs_class, outputs_masks):
    # this is a workaround to make torchscript happy, as torchscript
    # doesn't support dictionary with non-homogeneous values, such
    # as a dict having both a Tensor and a list.
    return [{'pred_logits': a, 'pred_masks': b}
            for a, b in zip(outputs_class[:-1], outputs_masks[:-1])]
aux_loss = True 
class_logits_list = []
for i in range(n_future):
    class_logits_list.append( class_mlp(hs[-1]))

outputs_class = torch.stack(class_logits_list)
print(outputs_class.shape)
mask_head = MaskHeadSmallConv(hidden_dim,fpn_dims)

outputs_masks = mask_head(
        srcs, seg_memory, input_projections,hs)


out = {'pred_logits': outputs_class[-1]}
out.update({'pred_masks': outputs_masks[-1]})

if aux_loss:
    out['aux_outputs'] = _set_aux_loss(outputs_class, outputs_masks)


torch.Size([5, 1, 100, 51])
torch.Size([1, 6, 100, 256])
torch.Size([1, 5, 6, 100, 256])


In [23]:

num_frames = 5
dice_weight=3.0
mask_weight=3.0
no_object_weight = 0.1 
deep_supervision = True
dec_layers = 3


matcher = HungarianMatcherIFC(
    cost_class=1,
    cost_dice=dice_weight,
    num_classes=num_classes,
    )
weight_dict = {"loss_ce": 1, "loss_mask": mask_weight,
                "loss_dice": dice_weight}
if deep_supervision:
    aux_weight_dict = {}
    for i in range(dec_layers - 1):
        aux_weight_dict.update(
            {k + f"_{i}": v for k, v in weight_dict.items()})
    weight_dict.update(aux_weight_dict)
losses = ["labels", "masks", "cardinality"]
criterion = SetCriterion(
    num_classes, matcher=matcher, weight_dict=weight_dict, eos_coef=no_object_weight, losses=losses,
    num_frames=num_frames
)


In [5]:
import math
import torch.nn.functional as F 

from projects.mmdet3d_plugin.datasets.utils.warper import FeatureWarper
import os
from mmdet.datasets import (build_dataloader, build_dataset,
                            replace_ImageToTensor)

import numpy as np 
import matplotlib.pyplot as plt 
from mmcv import Config


def import_modules_load_config(cfg_file="beverse_tiny.py", samples_per_gpu=1):
    cfg_path = r"/home/niklas/ETM_BEV/BEVerse/projects/configs"
    cfg_path = os.path.join(cfg_path, cfg_file)

    cfg = Config.fromfile(cfg_path)

    # if args.cfg_options is not None:
    #     cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get("custom_imports", None):
        from mmcv.utils import import_modules_from_strings

        import_modules_from_strings(**cfg["custom_imports"])

    # import modules from plguin/xx, registry will be updated
    if hasattr(cfg, "plugin"):
        if cfg.plugin:
            import importlib

            if hasattr(cfg, "plugin_dir"):
                plugin_dir = cfg.plugin_dir
                _module_dir = os.path.dirname(plugin_dir)
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]

                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)
            else:
                # import dir is the dirpath for the config file
                _module_dir = cfg_path
                _module_dir = _module_dir.split("/")
                _module_path = _module_dir[0]
                for m in _module_dir[1:]:
                    _module_path = _module_path + "." + m
                print(_module_path)
                plg_lib = importlib.import_module(_module_path)

    samples_per_gpu = 1
    if isinstance(cfg.data.test, dict):
        cfg.data.test.test_mode = True
        samples_per_gpu = cfg.data.test.pop("samples_per_gpu", 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            cfg.data.test.pipeline = replace_ImageToTensor(
                cfg.data.test.pipeline)
    elif isinstance(cfg.data.test, list):
        for ds_cfg in cfg.data.test:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop("samples_per_gpu", 1) for ds_cfg in cfg.data.test]
        )
        if samples_per_gpu > 1:
            for ds_cfg in cfg.data.test:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)

    return cfg


torch.backends.cudnn.benchmark = True
cfg = import_modules_load_config(
    cfg_file=r"beverse_tiny_org.py")

dataset = build_dataset(cfg.data.test)
data_loader = build_dataloader(
    dataset,
    samples_per_gpu=2,
    workers_per_gpu=cfg.data.workers_per_gpu,
    dist=False,
    shuffle=False)


grid_conf = {
    "xbound": [-50.0, 50.0, 0.5],
    "ybound": [-50.0, 50.0, 0.5],
    "zbound": [-10.0, 10.0, 20.0],
    "dbound": [1.0, 60.0, 1.0],
}

warper = FeatureWarper(grid_conf=grid_conf)


class pseud_class:
    def __init__(self) -> None:
        
        self.receptive_field = 3
        self.warper = FeatureWarper(grid_conf=grid_conf)
        
    def prepare_targets(self, batch,bev_size = (200,200), mask_stride=2,match_stride=2):
        segmentation_labels = batch["motion_segmentation"][0]
        gt_instance = batch["motion_instance"][0]
        future_egomotion = batch["future_egomotions"][0]
        batch_size = len(segmentation_labels)
        labels = {}

        bev_transform = batch.get("aug_transform", None)
        labels["img_is_valid"] = batch.get("img_is_valid", None)

        if bev_transform is not None:
            bev_transform = bev_transform.float()

        # Warp instance labels to present's reference frame
        gt_instance = (
            self.warper.cumulative_warp_features_reverse(
                gt_instance.float().unsqueeze(2),
                future_egomotion[:, (self.receptive_field - 1) :],
                mode="nearest",
                bev_transform=bev_transform,
            )
            .long()
            .contiguous()[:, :, 0]
        )
        # better solution by abdur but unsure how to make it work with the rest of the code specifcally maxID since it can be diffferent for batches
        # temp = torch.arange(MaxID).unsqueeze(0).repeat(B, 1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        # gt_masks_ifc_dim  = (temp== Target.unsqueeze(1)).float() 
        target_list = []
        for b in range(batch_size):
            gt_list = []
            ids = len(gt_instance[b].unique())
            for _id in range(ids):
                test_bool = torch.where(gt_instance[b] == _id,1.,0.)
                gt_list.append(test_bool)

            segmentation_labels = torch.stack(gt_list,dim=0)
            
            #segmentation_labels = torch.stack(gt_batch_instances_list,dim=0)
            o_h, o_w = bev_size
            l_h, l_w = math.ceil(o_h/mask_stride), math.ceil(o_w/mask_stride)
            m_h, m_w = math.ceil(o_h/match_stride), math.ceil(o_w/match_stride)

            gt_masks_for_loss  = F.interpolate(segmentation_labels, size=(l_h, l_w), mode="bilinear", align_corners=False)
            gt_masks_for_match = F.interpolate(segmentation_labels, size=(m_h, m_w), mode="bilinear", align_corners=False)
            
            ids = gt_instance[b].unique() # labels only continous for clip - this is much more of an tracking id as every class is a vehicle anyways # TODO make work with other types of superclasses other then vehicle
            target_list.append({"labels": ids,"masks": gt_masks_for_loss, "match_masks": gt_masks_for_match, "gt_motion_instance":gt_instance[b] })
        return target_list 

projects.mmdet3d_plugin


In [7]:
sample = next(iter(data_loader))

In [8]:
p=pseud_class()
target_list = p.prepare_targets(sample)


In [13]:
(target_list[0]["match_masks"].shape)


torch.Size([15, 5, 100, 100])

In [24]:
loss_dict = criterion(out, target_list)



cost_dice.shape = torch.Size([100, 15]) cost_class.shape = torch.Size([100, 15]) b_tgt_ids.shape = torch.Size([15]) b_out_prob.shape = torch.Size([100, 51]) b_tgt_mask.shape = torch.Size([15, 5, 100, 100]) b_out_mask.shape = torch.Size([100, 5, 100, 100])
cost_dice.shape = torch.Size([100, 15]) cost_class.shape = torch.Size([100, 15]) b_tgt_ids.shape = torch.Size([15]) b_out_prob.shape = torch.Size([100, 51]) b_tgt_mask.shape = torch.Size([15, 5, 100, 100]) b_out_mask.shape = torch.Size([100, 5, 100, 100])
cost_dice.shape = torch.Size([100, 15]) cost_class.shape = torch.Size([100, 15]) b_tgt_ids.shape = torch.Size([15]) b_out_prob.shape = torch.Size([100, 51]) b_tgt_mask.shape = torch.Size([15, 5, 100, 100]) b_out_mask.shape = torch.Size([100, 5, 100, 100])
cost_dice.shape = torch.Size([100, 15]) cost_class.shape = torch.Size([100, 15]) b_tgt_ids.shape = torch.Size([15]) b_out_prob.shape = torch.Size([100, 51]) b_tgt_mask.shape = torch.Size([15, 5, 100, 100]) b_out_mask.shape = torch.S

target class torch.Size([16, 100])
src_logits transposed 1 2 torch.Size([16, 100, 41])
empty weight torch.Size([41])

In [36]:
pred_masks_stacked = torch.stack(out["pred_masks"]).transpose_(1,0)
# list of BxQxHxW
pred_masks_stacked.shape

torch.Size([1, 4, 100, 100, 100])

In [32]:
gt_instances[0]["match_masks"].shape


torch.Size([1, 4, 100, 100])

In [43]:
out_prob = torch.stack(out["pred_logits"]).transpose(1, 0).softmax(-1)


In [38]:
out_prob = out["pred_logits"].softmax(-1)


AttributeError: 'list' object has no attribute 'softmax'

In [58]:
out_mask = torch.stack(out["pred_masks"]).transpose(1, 0)
out_mask.shape

torch.Size([1, 4, 100, 100, 100])