In [None]:
import os

# os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'
# os.environ["CUDA_VISIBLE_DEVICES"]="0"

import numpy as np
import torch
import torch.utils.data
from PIL import Image

import cv2

import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor

from glob import glob
from tqdm.autonotebook import tqdm

import torch.nn as nn
import torch.nn.functional as F

import open3d as o3d

import trimesh

import multiprocessing as mp

import yaml

import matplotlib.pyplot as plt

import h5py

import random

import roma

# from epd import *
from tqdm.autonotebook import tqdm
import scipy

# Preferred compute device; the model-setup cell falls back to CPU when CUDA is unavailable.
device = 'cuda'

## Load Preprocessed Data

In [None]:
# Read every array fully into memory and close the HDF5 file straight away, so
# the kernel does not keep an open handle (and file lock) for the whole session.
# After this block `hf` refers to a closed file.
with h5py.File('processed_data_subs.h5', 'r') as hf:
    rgbs = hf['rgbs'][:]
    depth_imgs = hf['depth_imgs'][:]
    merged_masks = hf['merged_masks'][:]

# NOTE(review): allow_pickle=True can execute code embedded in the file;
# only load these .npy files from a trusted source.
point_cloud_pairs = np.load('point_cloud_pairs.npy', allow_pickle=True)
# custom_gt.npy stores a pickled dict inside a 0-d array; .item() unwraps it.
gt = np.load('custom_gt.npy', allow_pickle=True).item()

In [None]:
# Grey levels that appear in the merged masks, listed in object-index order.
_MASK_GREY_LEVELS = [14, 28, 42, 56, 70, 85, 99, 113, 127, 141, 155, 170, 184, 198, 212, 226, 240, 255]

# Mask grey level -> zero-based object index.
obj_mask_value = {grey: index for index, grey in enumerate(_MASK_GREY_LEVELS)}
# Zero-based object index -> mask grey level.
obj_to_mask_value = dict(zip(obj_mask_value.values(), obj_mask_value.keys()))

# One entry per object; other cells read element [0] of an entry as its points and [1] as its colours.
# NOTE(review): allow_pickle=True can execute code stored in the file — load trusted files only.
reference_point_clouds = np.load('reference_point_clouds_custom.npy',allow_pickle=True)

## Visualize Data

### Reference Point Clouds

In [None]:
# Lay the reference clouds out on a grid (six per row, 300 units apart)
# after scaling their points by 1000.
GRID_COLUMNS = 6
GRID_SPACING = 300

scene = trimesh.Scene()
for i, pc in enumerate(reference_point_clouds):
    grid_row, grid_col = divmod(i, GRID_COLUMNS)
    shifted_points = pc[0] * 1000 + [grid_col * GRID_SPACING, grid_row * GRID_SPACING, 0]
    scene.add_geometry(trimesh.points.PointCloud(shifted_points, colors=pc[1]))
scene.camera.z_far = 10
scene.show(viewer='notebook')

In [None]:
# Diameter of each reference cloud: the largest pairwise point distance, times 1000.
diameters = []
for reference_cloud in reference_point_clouds:
    cloud_points = torch.Tensor(reference_cloud[0])
    largest_distance = torch.cdist(cloud_points, cloud_points).max()
    diameters.append(float(largest_distance.detach().numpy()) * 1000.0)

In [None]:
# Repeat each key of `gt` once per entry stored under it, so that position k
# of idx_list gives the gt key owning the k-th entry overall.
idx_list = [key for key, entries in gt.items() for _ in range(len(entries))]

### Sample Data

In [None]:
# Draw one random pair index below 8000, then fetch the frame that pair maps to.
rind = np.random.choice(8000)
point_cloud_pair = point_cloud_pairs[rind]
frame_index = idx_list[rind]
rgb, merged_mask, depth_img = (
    frames[frame_index] for frames in (rgbs, merged_masks, depth_imgs)
)

In [None]:
# Side-by-side view of the sampled frame: RGB image on the left, merged
# instance mask (grey colour map) on the right, both without axes.
plt.figure(figsize=(20, 10))
panels = [(rgb, {}), (merged_mask, {'cmap': 'Greys_r'})]
for position, (image, imshow_options) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(image, **imshow_options)
    plt.axis('off')

In [None]:
# Render the first three depth channels as a point cloud (NaNs zeroed, scaled
# by 1000), coloured with the matching RGB pixels.
scene = trimesh.Scene()
dp_i = np.array(depth_img[:, :, 0:3])
np.putmask(dp_i, np.isnan(dp_i), 0.0)
scene_points = dp_i.reshape([-1, 3]) * 1000.0
scene_colours = rgb.reshape([-1, 3])
scene.add_geometry(trimesh.points.PointCloud(scene_points, colors=scene_colours))
scene.camera.z_far = 10
scene.show(viewer='notebook')

In [None]:
scene = trimesh.Scene()

# Observed cloud: element 0 of the pair holds the points, element 1 the colours.
source_points, source_colours = point_cloud_pair[0], point_cloud_pair[1]
pc_source = trimesh.points.PointCloud(source_points, colors=source_colours)
scene.add_geometry(pc_source)

# Reference cloud picked by (element 2 - 1), transformed with the 3x3 matrix
# from element 3 and the offset from element 4, and drawn in green.
pair_matrix = np.reshape(point_cloud_pair[3], (3, 3))
reference_points = np.copy(reference_point_clouds[point_cloud_pair[2] - 1][0])
r_pc = (pair_matrix @ reference_points.T).T + point_cloud_pair[4]
pc_target = trimesh.points.PointCloud(r_pc, colors=[0, 255, 0])
scene.add_geometry(pc_target)

scene.camera.z_far = 10
scene.show(viewer='notebook')

## Instance Segmentation Training

In [None]:
class ManipulationDataset(torch.utils.data.Dataset):
    """Instance-segmentation dataset over the in-memory ``rgbs`` / ``merged_masks`` arrays.

    Images and masks are read from the notebook-level arrays ``rgbs`` and
    ``merged_masks``; ``root`` is only stored on the instance and is not used
    to load anything.

    Args:
        root: Dataset location (kept for reference only).
        transforms: Optional callable ``(img, target) -> (img, target)`` applied
            before the image is scaled and normalised.
    """

    def __init__(self, root, transforms=None):
        self.root = root
        self.transforms = transforms
        # Image i and mask i share the same index into the global arrays.
        self.imgs = list(range(len(rgbs)))
        self.masks = list(range(len(rgbs)))

        self.obj_id_dict = {0: '055_baseball', 1: '056_tennis_ball', 2: '072-a_toy_airplane', 3: '019_pitcher_base', 4: '040_large_marker', 5: '021_bleach_cleanser', 6: '077_rubiks_cube', 7: '048_hammer', 8: '008_pudding_box', 9: '053_mini_soccer_ball', 10: '011_banana', 11: '006_mustard_bottle', 12: '013_apple', 13: '029_plate', 14: '035_power_drill', 15: '043_phillips_screwdriver', 16: '032_knife', 17: '042_adjustable_wrench'}
        self.obj_mask_value = {14: 0, 28: 1, 42: 2, 56: 3, 70: 4, 85: 5, 99: 6, 113: 7, 127: 8, 141: 9, 155: 10, 170: 11, 184: 12, 198: 13, 212: 14, 226: 15, 240: 16, 255: 17}

        # Built once here instead of on every __getitem__ call.
        self.normalize = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    def __getitem__(self, idx):
        """Return ``(image, target)`` for sample ``idx``.

        Returns:
            image: float32 tensor in channel-first layout, divided by 255 and
                normalised with the mean/std set in ``__init__``.
            target: dict with ``boxes`` (N, 4) as [xmin, ymin, xmax, ymax],
                ``labels`` (object index + 1, so 0 stays free for background),
                ``masks`` (one uint8 binary mask per instance), ``image_id``,
                ``area`` and ``iscrowd`` (all zeros).

        Raises:
            KeyError: if the mask contains a non-zero value that is not a key
                of ``self.obj_mask_value``.
        """
        img = torch.Tensor(np.transpose(rgbs[self.imgs[idx]], [2, 0, 1]))
        # The mask stays single-channel: each grey level marks one instance,
        # with 0 being background.
        mask = merged_masks[self.masks[idx]]

        # Drop the background by value rather than by position, so a mask
        # without any background pixel does not lose its first real object.
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids != 0]

        # Split the grey-level mask into one binary mask per instance.
        masks = mask == obj_ids[:, None, None]

        num_objs = len(obj_ids)
        labels = torch.zeros((num_objs,), dtype=torch.int64)

        boxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            boxes.append([xmin, ymin, xmax, ymax])

            labels[i] = self.obj_mask_value[int(obj_ids[i])] + 1

        # reshape keeps the (N, 4) layout even when the image has no objects,
        # so the column indexing below works for N == 0 as well.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        masks = torch.as_tensor(masks, dtype=torch.uint8)

        image_id = torch.tensor([idx])
        area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)

        target = {
            "boxes": boxes,
            "labels": labels,
            "masks": masks,
            "image_id": image_id,
            "area": area,
            "iscrowd": iscrowd,
        }

        if self.transforms is not None:
            img, target = self.transforms(img, target)

        img = img.to(torch.float32) / 255.0

        return self.normalize(img), target

    def __len__(self):
        """Number of samples (one per entry of ``rgbs`` at construction time)."""
        return len(self.imgs)

In [None]:
def build_model(num_classes):
    """Create a COCO-pretrained Mask R-CNN whose heads predict ``num_classes`` classes.

    Both the box predictor and the mask predictor are replaced with freshly
    initialised heads; the rest of the pretrained network is kept.

    Args:
        num_classes: Number of output classes passed to both new heads.

    Returns:
        The modified ``maskrcnn_resnet50_fpn`` model.
    """
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
    roi_heads = model.roi_heads

    # Box head: keep its input width, swap the classifier.
    # (For plain Faster R-CNN fine-tuning this would be the only change.)
    box_in_features = roi_heads.box_predictor.cls_score.in_features
    roi_heads.box_predictor = FastRCNNPredictor(box_in_features, num_classes)

    # Mask head: keep its input channels, use a 256-channel hidden layer.
    mask_in_channels = roi_heads.mask_predictor.conv5_mask.in_channels
    mask_hidden_channels = 256
    roi_heads.mask_predictor = MaskRCNNPredictor(mask_in_channels, mask_hidden_channels, num_classes)

    return model

In [None]:
# Fixed split by image index: the first 900 images train, the next 100 test.
NUM_TRAIN_IMAGES = 900
NUM_SPLIT_IMAGES = 1000

train = list(range(NUM_TRAIN_IMAGES))
test = list(range(NUM_TRAIN_IMAGES, NUM_SPLIT_IMAGES))

In [None]:
from engine import train_one_epoch, evaluate
import utils
import transforms as T


def get_transform(train):
    """Build the transform pipeline applied to each (image, target) pair.

    Args:
        train: When truthy, a random horizontal flip (probability 0.5) is
            included as augmentation; otherwise the pipeline is empty.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    steps = [T.RandomHorizontalFlip(0.5)] if train else []
    return T.Compose(steps)


DATASET_ROOT = '../manipulation/dataset/'

# Two dataset views over the same arrays: flips enabled for training only.
dataset = ManipulationDataset(DATASET_ROOT, get_transform(train=True))
dataset_test = ManipulationDataset(DATASET_ROOT, get_transform(train=False))

torch.manual_seed(1)
# Restrict each view to its share of the fixed index split.
dataset = torch.utils.data.Subset(dataset, train)
dataset_test = torch.utils.data.Subset(dataset_test, test)

# Options shared by both loaders; only shuffling differs.
loader_options = dict(batch_size=16, num_workers=4, collate_fn=utils.collate_fn)
data_loader = torch.utils.data.DataLoader(dataset, shuffle=True, **loader_options)
data_loader_test = torch.utils.data.DataLoader(dataset_test, shuffle=False, **loader_options)

In [None]:
# Use the configured device when CUDA is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(device)
else:
    device = torch.device('cpu')

# Matches the 19 entries of CLASS_NAMES (background included).
num_classes = 19

model = build_model(num_classes)
model.to(device)

# Optimise only the parameters that require gradients.
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.Adam(params, lr=1e-3)

# Multiply the learning rate by 0.1 every 30 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

In [None]:
# number of epochs
CHECKPOINT_PATH = 'mask-rcnn-custom-data.pt'

if os.path.exists(CHECKPOINT_PATH):
    # Reuse a previously trained model. map_location lets a checkpoint saved
    # on a GPU load on whichever device was selected above.
    model = torch.load(CHECKPOINT_PATH, map_location=device)
else:
    num_epochs = 100
    # The original expression was int(split / 16), but `split` is not defined
    # in the cells above, so a fresh kernel failed with NameError here.
    # NOTE(review): `split` was presumably the training-set size — confirm.
    epoch_length = len(dataset) // data_loader.batch_size

    for epoch in range(num_epochs):
        train_one_epoch(model, optimizer, data_loader, device, epoch,
                        print_freq=100000, epoch_length=epoch_length)
        lr_scheduler.step()
        evaluate(model, data_loader_test, device=device)
    # NOTE(review): nothing in this cell writes CHECKPOINT_PATH; unless the
    # model is saved elsewhere, training repeats on every fresh run.

In [None]:
# Object names in label order; index 0 of CLASS_NAMES is the background class.
OBJECT_NAMES = [
    '055_baseball',
    '056_tennis_ball',
    '072-a_toy_airplane',
    '019_pitcher_base',
    '040_large_marker',
    '021_bleach_cleanser',
    '077_rubiks_cube',
    '048_hammer',
    '008_pudding_box',
    '053_mini_soccer_ball',
    '011_banana',
    '006_mustard_bottle',
    '013_apple',
    '029_plate',
    '035_power_drill',
    '043_phillips_screwdriver',
    '032_knife',
    '042_adjustable_wrench',
]
CLASS_NAMES = ['__background__', *OBJECT_NAMES]

# Switch to inference mode and make sure the model sits on the active device.
model.eval()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

In [None]:
def get_coloured_mask(mask):
    """Paint the foreground of a binary mask with one randomly chosen colour.

    Args:
        mask: 2-D numpy array; elements equal to 1 (or True) are foreground.

    Returns:
        uint8 array of shape ``mask.shape + (3,)``: foreground pixels carry
        the chosen colour, all other pixels are zero.
    """
    colours = [[0, 255, 0],[0, 0, 255],[255, 0, 0],[0, 255, 255],[255, 255, 0],[255, 0, 255],[80, 70, 180],[250, 80, 190],[245, 145, 50],[70, 150, 250],[50, 190, 190]]
    # random.choice covers the whole palette; the previous randrange(0, 10)
    # could never select the last of the eleven colours.
    colour = random.choice(colours)
    coloured_mask = np.zeros(np.shape(mask) + (3,), dtype=np.uint8)
    coloured_mask[mask == 1] = colour
    return coloured_mask

def get_prediction(img, confidence):
    """Run the notebook-level ``model`` on one image and keep the confident detections.

    Args:
        img: Image tensor in the form the model expects; it is moved to the
            notebook-level ``device`` before inference.
        confidence: Score threshold. Every detection up to and including the
            last one whose score exceeds it is kept.

    Returns:
        masks: Boolean array with one mask per kept detection (soft masks
            thresholded at 0.5); the instance axis is kept even for a single
            detection.
        pred_boxes: List of ``[(x1, y1), (x2, y2)]`` corner pairs.
        pred_class: List of class names looked up in ``CLASS_NAMES``.
        All three are empty when no score exceeds ``confidence``.
    """
    img = img.to(device)
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        pred = model([img])[0]

    scores = pred['scores'].detach().cpu().numpy()
    # Positional indices avoid the duplicate-score problem of list.index().
    confident = [i for i, score in enumerate(scores) if score > confidence]
    # Previously an empty list here raised IndexError; now nothing is kept.
    num_keep = confident[-1] + 1 if confident else 0

    # Squeeze only the channel axis; a bare squeeze() also removed the
    # instance axis when exactly one object was detected.
    masks = (pred['masks'] > 0.5).squeeze(1).detach().cpu().numpy()[:num_keep]
    labels = pred['labels'].cpu().numpy()[:num_keep]
    boxes = pred['boxes'].detach().cpu().numpy()[:num_keep]

    pred_class = [CLASS_NAMES[i] for i in labels]
    pred_boxes = [[(b[0], b[1]), (b[2], b[3])] for b in boxes]
    return masks, pred_boxes, pred_class

def segment_instance(img, rgb_img, confidence=0.5, rect_th=2, text_size=0.75, text_th=2):
    """Predict instances for ``img`` and display them drawn over ``rgb_img``.

    Args:
        img: Model input passed on to ``get_prediction``.
        rgb_img: Image array the overlays are blended onto.
        confidence: Score threshold passed on to ``get_prediction``.
        rect_th: Line thickness of the bounding boxes.
        text_size: Font scale of the class labels.
        text_th: Line thickness of the class labels.

    Each mask gets a random colour and is blended in with weights 1 (image)
    and 0.5 (mask); a green box and the class name are then drawn, and the
    result is shown with matplotlib.
    """
    masks, boxes, pred_cls = get_prediction(img, confidence)
    green = (0, 255, 0)
    canvas = rgb_img
    for i, mask in enumerate(masks):
        overlay = get_coloured_mask(mask)
        canvas = cv2.addWeighted(canvas, 1, overlay, 0.5, 0)
        (x_min, y_min), (x_max, y_max) = boxes[i]
        top_left = [int(x_min), int(y_min)]
        bottom_right = [int(x_max), int(y_max)]
        cv2.rectangle(canvas, top_left, bottom_right, color=green, thickness=rect_th)
        cv2.putText(canvas, pred_cls[i], top_left, cv2.FONT_HERSHEY_SIMPLEX, text_size, green, thickness=text_th)

    plt.figure(figsize=(10, 15))
    plt.imshow(canvas)
    plt.xticks([])
    plt.yticks([])
    plt.show()

In [None]:
# Visualise predictions for one random test image: the dataset supplies the
# model input, the overlay is drawn on the matching raw RGB frame.
rand_i = np.random.choice(len(test))
model_input = dataset_test[rand_i][0]
raw_frame = rgbs[test[rand_i]]
segment_instance(model_input, raw_frame, confidence=0.7)

## Pose Prediction Model

In [None]:
from typing import Optional, Any, Union, Callable

class TransformerEncoderLayer_NoNorm(nn.Module):
    r"""Transformer encoder layer (self-attention + feed-forward) without layer normalisation.

    Adapted from ``torch.nn.TransformerEncoderLayer``. Both residual blocks are
    kept, but the two ``LayerNorm`` modules are removed, so the layer computes::

        x = src + Dropout(SelfAttention(src))
        x = x   + Dropout(Linear2(Dropout(activation(Linear1(x)))))

    Because there are no norm modules, the fused fast path of the stock layer
    (which needs norm weights) is never used; ``forward`` always runs the
    plain implementation above.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multi-head attention (required).
        dim_feedforward: the dimension of the feedforward network (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation of the intermediate layer, either a string
            ("relu" or "gelu") or a unary callable. Default: relu
        layer_norm_eps: accepted for signature compatibility with
            ``nn.TransformerEncoderLayer``; unused because there is no norm.
        batch_first: If ``True``, input and output tensors are
            (batch, seq, feature). Default: ``False`` (seq, batch, feature).
        norm_first: stored on the module for compatibility; it has no effect
            because there is no normalisation to order.
        device, dtype: forwarded to the attention and linear sub-modules.

    Examples::
        >>> layer = TransformerEncoderLayer_NoNorm(d_model=512, nhead=8, batch_first=True)
        >>> src = torch.rand(32, 10, 512)
        >>> out = layer(src)
    """
    __constants__ = ['batch_first', 'norm_first']

    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 activation: Union[str, Callable[[torch.Tensor], torch.Tensor]] = F.relu,
                 layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(TransformerEncoderLayer_NoNorm, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                               **factory_kwargs)
        # Feed-forward network
        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # String support for the activation. This is resolved here because the
        # private torch helper the original code called is not in scope.
        if isinstance(activation, str):
            if activation == "relu":
                activation = F.relu
            elif activation == "gelu":
                activation = F.gelu
            else:
                raise RuntimeError(f"activation should be relu/gelu, not {activation}")

        # Kept for attribute compatibility: 1 = relu, 2 = gelu, 0 = anything else.
        if activation is F.relu or isinstance(activation, torch.nn.ReLU):
            self.activation_relu_or_gelu = 1
        elif activation is F.gelu or isinstance(activation, torch.nn.GELU):
            self.activation_relu_or_gelu = 2
        else:
            self.activation_relu_or_gelu = 0
        self.activation = activation

    def __setstate__(self, state):
        super(TransformerEncoderLayer_NoNorm, self).__setstate__(state)
        # Restored state without an activation falls back to relu.
        if not hasattr(self, 'activation'):
            self.activation = F.relu

    def forward(self, src: torch.Tensor, src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the attention mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Returns:
            Tensor with the same shape as ``src``.

        Raises:
            AssertionError: if ``src_key_padding_mask`` is neither bool nor
                floating point.
        """
        if src_key_padding_mask is not None:
            _skpm_dtype = src_key_padding_mask.dtype
            if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
                raise AssertionError(
                    "only bool and floating types of key_padding_mask are supported")

        # The fast-path checks copied from the stock layer read self.norm1 and
        # self.norm2, which this class does not define, so eval-mode calls with
        # batch_first=True raised AttributeError. The fused kernel would also
        # need norm parameters, so it is dropped altogether.
        # Without norms, the pre-norm and post-norm orderings are identical.
        x = src
        x = x + self._sa_block(x, src_mask, src_key_padding_mask)
        x = x + self._ff_block(x)
        return x

    # self-attention block
    def _sa_block(self, x: torch.Tensor,
                  attn_mask: Optional[torch.Tensor], key_padding_mask: Optional[torch.Tensor]) -> torch.Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False)[0]
        return self.dropout1(x)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)


In [None]:
#Identical to DenseFusion
from pspnet import PSPNet
#Identical to DenseFusion
def _psp_factory(backend, psp_size, deep_features_size):
    """Return a zero-argument constructor for a PSPNet on the given backbone."""
    return lambda: PSPNet(sizes=(1, 2, 3, 6), psp_size=psp_size,
                          deep_features_size=deep_features_size, backend=backend)


# Backbone name -> zero-argument PSPNet constructor.
psp_models = {
    'resnet18': _psp_factory('resnet18', 512, 256),
    'resnet34': _psp_factory('resnet34', 512, 256),
    'resnet50': _psp_factory('resnet50', 2048, 1024),
    'resnet101': _psp_factory('resnet101', 2048, 1024),
    'resnet152': _psp_factory('resnet152', 2048, 1024),
}
#Identical to DenseFusion

# New
class ModifiedResnet(nn.Module):
    """PSPNet feature extractor wrapped in ``nn.DataParallel``.

    Args:
        usegpu: Accepted but not used by this class.
        large: Pick the 'resnet50' PSPNet when true, otherwise 'resnet18'.
    """

    def __init__(self, usegpu=True, large=False):
        super().__init__()
        backbone = 'resnet50' if large else 'resnet18'
        self.model = nn.DataParallel(psp_models[backbone]())

    def forward(self, x):
        """Return the wrapped PSPNet's output for ``x``."""
        return self.model(x)

class TransformationNet(nn.Module):
    """T-Net style regressor: predicts one square matrix per point set.

    Per-point features are produced by five Linear + LayerNorm + ReLU layers,
    max-pooled over the point axis, and passed through three fully connected
    layers. The identity matrix is added to the result.

    Args:
        input_dim: Size of the last axis of the input.
        output_dim: Side length of the predicted square matrix.
    """

    def __init__(self, input_dim, output_dim):
        super(TransformationNet, self).__init__()
        self.output_dim = output_dim
        self.in1 = nn.Linear(input_dim, 64)
        self.bn_1 = nn.LayerNorm(64)

        self.in2 = nn.Linear(64, 128)
        self.bn_2 = nn.LayerNorm(128)

        self.in3 = nn.Linear(128, 128)
        self.bn_3 = nn.LayerNorm(128)

        self.in4 = nn.Linear(128, 256)
        self.bn_4 = nn.LayerNorm(256)

        self.in5 = nn.Linear(256, 512)
        self.bn_5 = nn.LayerNorm(512)

        self.fc_1 = nn.Linear(512, 512)
        self.bn_6 = nn.LayerNorm(512)
        self.fc_2 = nn.Linear(512, 512)
        self.bn_7 = nn.LayerNorm(512)
        self.fc_3 = nn.Linear(512, self.output_dim*self.output_dim)

    def forward(self, x):
        """Map ``x`` to a batch of ``output_dim x output_dim`` matrices.

        Args:
            x: Tensor whose axis 1 indexes points and whose last axis has
                ``input_dim`` features (assumed (batch, points, input_dim) —
                TODO confirm with callers).

        Returns:
            Tensor of shape (-1, output_dim, output_dim) on the same device
            as ``x``.
        """
        num_points = x.shape[1]
        x = F.relu(self.bn_1(self.in1(x)))
        x = F.relu(self.bn_2(self.in2(x)))
        x = F.relu(self.bn_3(self.in3(x)))
        x = F.relu(self.bn_4(self.in4(x)))
        x = F.relu(self.bn_5(self.in5(x)))

        # Pool over the point axis: (batch, 512, points) -> (batch, 512).
        x = F.max_pool1d(x.transpose(2, 1), num_points)
        x = x.view(-1, 512)

        x = F.relu(self.bn_6(self.fc_1(x)))
        x = F.relu(self.bn_7(self.fc_2(x)))
        x = self.fc_3(x)

        # Create the identity on the input's device. Moving it to CUDA whenever
        # CUDA exists broke CPU-resident models on GPU machines.
        identity_matrix = torch.eye(self.output_dim, device=x.device)
        x = x.view(-1, self.output_dim, self.output_dim) + identity_matrix
        return x


class BasePointNet(nn.Module):
    """Shared PointNet trunk: per-point MLP followed by a global max-pool.

    The two ``TransformationNet`` sub-modules are constructed (so their
    parameters stay part of the module) but are not applied in ``forward``.

    Args:
        point_dimension: Number of features per input point.
    """

    def __init__(self, point_dimension):
        super(BasePointNet, self).__init__()
        self.input_transform = TransformationNet(input_dim=point_dimension, output_dim=point_dimension)
        self.feature_transform = TransformationNet(input_dim=64, output_dim=64)

        self.conv_1 = nn.Linear(point_dimension, 64)
        self.conv_2 = nn.Linear(64, 64)
        self.conv_3 = nn.Linear(64, 128)
        self.conv_5 = nn.Linear(128, 256)
        self.conv_7 = nn.Linear(256, 1024)

        self.bn_1 = nn.LayerNorm(64)
        self.bn_2 = nn.LayerNorm(64)
        self.bn_3 = nn.LayerNorm(128)
        self.bn_5 = nn.LayerNorm(256)
        self.bn_7 = nn.LayerNorm(1024)

    def forward(self, x, plot=False):
        """Encode a batch of point sets.

        Args:
            x: Tensor whose axis 1 indexes points and whose last axis has
                ``point_dimension`` features (assumed (batch, points, dim) —
                TODO confirm with callers).
            plot: Accepted but not used.

        Returns:
            Tuple ``(x_g, x_perp, feature_transform, tnet_out, ix)``:
            ``x_g`` is the max-pooled global feature viewed as (-1, 1024);
            ``x_perp`` is the 64-feature per-point output of the second layer;
            ``feature_transform`` is an all-zero (1, 2, 2) placeholder;
            ``tnet_out`` is a numpy copy of the unmodified input;
            ``ix`` holds the indices returned by the max-pool.
        """
        num_points = x.shape[1]

        # The input T-Net is disabled, so the "transformed" input is the input itself.
        tnet_out = x.cpu().detach().numpy()
        # Placeholder created on the input's device rather than on the
        # notebook-level `device`, so the module works wherever its input lives.
        feature_transform = torch.zeros([1, 2, 2], device=x.device)

        x = F.relu(self.bn_1(self.conv_1(x)))
        x = F.relu(self.bn_2(self.conv_2(x)))
        x_perp = x

        x = F.relu(self.bn_3(self.conv_3(x)))
        x = F.relu(self.bn_5(self.conv_5(x)))
        x = F.relu(self.bn_7(self.conv_7(x)))

        # Pool over the point axis after moving it last.
        x = x.transpose(2, 1)
        x_g, ix = nn.MaxPool1d(num_points, return_indices=True)(x)
        x_g = x_g.view(-1, 1024)  # global feature vector

        return x_g, x_perp, feature_transform, tnet_out, ix

class SegmentaionPointNet(nn.Module):
    """Per-point PointNet head: each point's 64 local features are joined with
    the 1024-feature global vector and mapped to ``out_dim`` outputs.

    Args:
        out_dim: Number of output features per point.
        dropout: Accepted but not used.
        point_dimension: Number of features per input point.
    """

    def __init__(self, out_dim=128, dropout=0.3, point_dimension=6):
        super().__init__()
        self.base_pointnet = BasePointNet(point_dimension=point_dimension)

        # 1088 = 64 per-point features + 1024 global features.
        self.fc_1 = nn.Linear(1088, 512)
        self.fc_2 = nn.Linear(512, 512)
        self.fc_3 = nn.Linear(512, out_dim)

        self.bn_1 = nn.LayerNorm(512)
        self.bn_2 = nn.LayerNorm(512)

    def forward(self, x):
        """Return ``(per_point_output, feature_transform, tnet_out, ix_maxpool)``;
        the last three come straight from the base network."""
        global_feat, point_feat, feature_transform, tnet_out, ix_maxpool = self.base_pointnet(x)

        # Copy the global vector once per point so it can be concatenated.
        num_points = point_feat.shape[1]
        tiled_global = torch.tile(torch.unsqueeze(global_feat, 1), [1, num_points, 1])
        fused = torch.cat([point_feat, tiled_global], -1)

        hidden = F.relu(self.bn_1(self.fc_1(fused)))
        hidden = F.relu(self.bn_2(self.fc_2(hidden)))

        return self.fc_3(hidden), feature_transform, tnet_out, ix_maxpool
    
class ClassificationPointNet(nn.Module):
    """Global PointNet head: maps the 1024-feature global vector of a point set
    to ``out_dim`` outputs.

    Args:
        out_dim: Number of output features per point set.
        dropout: Accepted but not used.
        point_dimension: Number of features per input point.
    """

    def __init__(self, out_dim, dropout=0.3, point_dimension=6):
        super().__init__()
        self.base_pointnet = BasePointNet(point_dimension=point_dimension)

        self.fc_1 = nn.Linear(1024, 512)
        self.fc_2 = nn.Linear(512, 512)
        self.fc_3 = nn.Linear(512, out_dim)

        self.bn_1 = nn.LayerNorm(512)
        self.bn_2 = nn.LayerNorm(512)

    def forward(self, x):
        """Return ``(output, feature_transform, tnet_out, ix_maxpool)``;
        the last three come straight from the base network."""
        global_feat, _, feature_transform, tnet_out, ix_maxpool = self.base_pointnet(x)

        hidden = F.relu(self.bn_1(self.fc_1(global_feat)))
        hidden = F.relu(self.bn_2(self.fc_2(hidden)))

        return self.fc_3(hidden), feature_transform, tnet_out, ix_maxpool
    
class PointCloud_Matching(nn.Module):
    """Regress two 3-vectors from a (source, target) pair of point clouds.

    Each cloud is encoded to 512 features by its own ``ClassificationPointNet``;
    the concatenated 1024 features pass through two hidden layers and two
    separate 3-output heads.
    """

    def __init__(self):
        super().__init__()
        self.source_point_net = ClassificationPointNet(512)
        self.target_point_net = ClassificationPointNet(512)

        self.fc_1 = nn.Linear(1024, 512)
        self.fc_2 = nn.Linear(512, 512)
        self.fc_3 = nn.Linear(512, 3)
        self.fc_4 = nn.Linear(512, 3)

        self.bn_1 = nn.LayerNorm(512)
        self.bn_2 = nn.LayerNorm(512)

    def forward(self, x):
        """Args:
            x: Pair ``(source_pcs, target_pcs)`` of point-cloud batches.

        Returns:
            ``(R, t, s_ft, t_ft)``: ``R`` is the first head's output squashed
            by tanh and scaled by pi (values in (-pi, pi)), ``t`` is the raw
            second head's output, and ``s_ft`` / ``t_ft`` are the feature
            transform tensors returned by the two encoders.
        """
        source_pcs, target_pcs = x

        source_feat, s_ft, _, _ = self.source_point_net(source_pcs)
        target_feat, t_ft, _, _ = self.target_point_net(target_pcs)

        joint = torch.cat([source_feat, target_feat], dim=-1)
        hidden = F.relu(self.bn_1(self.fc_1(joint)))
        hidden = F.relu(self.bn_2(self.fc_2(hidden)))

        R = torch.tanh(self.fc_3(hidden)) * np.pi
        t = self.fc_4(hidden)

        return R, t, s_ft, t_ft
    
class PointCloud_Matching_pixel_wise(nn.Module):
    """Per-point pose regression fusing point-cloud and image features.

    Per-point features of the source cloud are combined with a global code of
    the target cloud and with CNN features sampled at the masked pixels. Three
    identical MLP heads then predict, for every point, a rotation vector, a
    translation and a scalar ``c``.
    """

    def __init__(self):
        super().__init__()
        self.source_point_net = SegmentaionPointNet(128, point_dimension=6)
        self.target_point_net = ClassificationPointNet(64, point_dimension=6)
        self.conv = ModifiedResnet()

        self.fc_1 = nn.Linear(128 + 64, 32)
        self.fc_2 = nn.Linear(64, 256)
        self.fc_3 = nn.Linear(256, 32)

        self.bn_1 = nn.LayerNorm(32)
        self.bn_2 = nn.LayerNorm(256)
        self.bn_3 = nn.LayerNorm(32)

        head_in = 3 * 32

        # Rotation head.
        self.r_fc_1 = nn.Linear(head_in, 512)
        self.r_bn_1 = nn.LayerNorm(512)
        self.r_fc_2 = nn.Linear(512, 512)
        self.r_bn_2 = nn.LayerNorm(512)
        self.r_fc_3 = nn.Linear(512, 3)

        # Translation head.
        self.t_fc_1 = nn.Linear(head_in, 512)
        self.t_bn_1 = nn.LayerNorm(512)
        self.t_fc_2 = nn.Linear(512, 512)
        self.t_bn_2 = nn.LayerNorm(512)
        self.t_fc_3 = nn.Linear(512, 3)

        # Scalar head ``c``.
        self.c_fc_1 = nn.Linear(head_in, 512)
        self.c_bn_1 = nn.LayerNorm(512)
        self.c_fc_2 = nn.Linear(512, 512)
        self.c_bn_2 = nn.LayerNorm(512)
        self.c_fc_3 = nn.Linear(512, 1)

    def forward(self, x):
        """Return ``(R, t, c, s_ft, t_ft)`` for ``x = (source_pcs, target_pcs, rgb, mask)``.

        Only the first image of the CNN output is used; ``mask`` selects the
        pixels whose features are attached to the source points.
        """
        source_pcs, target_pcs, rgb, mask = x

        # CNN features -> channels-last -> keep masked pixels of image 0.
        pixel_feat = self.conv(rgb)
        pixel_feat = torch.permute(pixel_feat, [0, 2, 3, 1])[0][mask]
        pixel_feat = torch.unsqueeze(pixel_feat, 0)

        point_feat, s_ft, _, _ = self.source_point_net(source_pcs)
        target_code, t_ft, _, _ = self.target_point_net(target_pcs)

        # Broadcast the global target code to every source point.
        target_code = torch.unsqueeze(target_code, 1)
        target_code = torch.tile(target_code, [1, point_feat.shape[1], 1])

        fused = torch.cat([point_feat, target_code], dim=-1)
        fused = F.relu(self.bn_1(self.fc_1(fused)))

        per_point = torch.cat([fused, pixel_feat], -1)

        pooled_in = F.relu(self.bn_2(self.fc_2(per_point)))
        pooled_in = F.relu(self.bn_3(self.fc_3(pooled_in)))

        # Average over all points, then hand the summary back to each point.
        n_points = pooled_in.shape[1]
        summary = F.avg_pool1d(pooled_in.transpose(2, 1), n_points).transpose(2, 1)
        summary = torch.tile(summary, [1, per_point.shape[1], 1])

        head_input = torch.cat([per_point, summary], -1)

        R = F.relu(self.r_bn_1(self.r_fc_1(head_input)))
        R = F.relu(self.r_bn_2(self.r_fc_2(R)))
        R = torch.tanh(self.r_fc_3(R)) * np.pi

        t = F.relu(self.t_bn_1(self.t_fc_1(head_input)))
        t = F.relu(self.t_bn_2(self.t_fc_2(t)))
        t = self.t_fc_3(t)

        c = F.relu(self.c_bn_1(self.c_fc_1(head_input)))
        c = F.relu(self.c_bn_2(self.c_fc_2(c)))
        c = self.c_fc_3(c)

        return R, t, c, s_ft, t_ft

class adaptive_layer_norm(nn.Module):
    """LayerNorm whose output is modulated by a conditioning vector.

    Two linear maps turn the condition ``c`` into a per-channel scale
    (``gamma``) and shift (``beta``); the result is
    ``norm(x) + norm(x) * gamma(c) + beta(c)``.

    Args:
        channels: size of the normalised (last) dimension of ``x``.
        condition_size: size of the last dimension of ``c``.
    """

    def __init__(self, channels, condition_size):
        super().__init__()

        self.norm = nn.LayerNorm(channels)

        self.beta = nn.Linear(condition_size, channels)
        self.gamma = nn.Linear(condition_size, channels)

    def forward(self, x, c):
        """Normalise ``x`` and apply the scale/shift predicted from ``c``."""
        normed = self.norm(x)

        scale = self.gamma(c)
        shift = self.beta(c)

        # Residual form: a zero scale leaves the normalised input unchanged.
        return normed + normed * scale + shift

class PointCloud_Matching_Graph(nn.Module):
    """Graph-network variant of the point-cloud matcher.

    Source and target graphs are encoded by separate
    ``EncoderProcessorDecoder`` networks and max-pooled per graph. The pooled
    target code conditions every ``adaptive_layer_norm`` applied to the source
    branch, which ends in a rotation head and a translation head.
    """

    def __init__(self):
        super().__init__()
        self.source_gnn = EncoderProcessorDecoder(6, 1, 128, 128, 2, True, 5, 128)
        self.target_gnn = EncoderProcessorDecoder(6, 1, 128, 128, 2, True, 5, 128)

        # Shared trunk.
        self.fc_1 = nn.Linear(128, 256)
        self.fc_2 = nn.Linear(256, 512)
        self.fc_3 = nn.Linear(512, 512)

        # Rotation head.
        self.r_fc_1 = nn.Linear(512, 512)
        self.r_fc_2 = nn.Linear(512, 512)
        self.r_fc_3 = nn.Linear(512, 3)

        # Translation head.
        self.t_fc_1 = nn.Linear(512, 512)
        self.t_fc_2 = nn.Linear(512, 512)
        self.t_fc_3 = nn.Linear(512, 3)

        # Conditioned norms: trunk, rotation head, translation head.
        self.aln1 = adaptive_layer_norm(256, 128)
        self.aln2 = adaptive_layer_norm(512, 128)
        self.aln3 = adaptive_layer_norm(512, 128)

        self.raln1 = adaptive_layer_norm(512, 128)
        self.raln2 = adaptive_layer_norm(512, 128)

        self.taln1 = adaptive_layer_norm(512, 128)
        self.taln2 = adaptive_layer_norm(512, 128)

    def forward(self, x):
        """Return ``(R, t)`` for ``x = (source_graph, target_graph, s_b, t_b)``.

        ``s_b`` / ``t_b`` are the batch-assignment arguments handed to
        ``global_max_pool`` for the source / target node features.
        """
        source_graph, target_graph, s_b, t_b = x

        source_nodes = self.source_gnn(source_graph)
        target_nodes = self.target_gnn(target_graph)

        source_code = torch_geometric.nn.global_max_pool(source_nodes, s_b)
        condition = torch_geometric.nn.global_max_pool(target_nodes, t_b)

        trunk = F.elu(self.aln1(self.fc_1(source_code), condition))
        trunk = F.elu(self.aln2(self.fc_2(trunk), condition))
        trunk = F.elu(self.aln3(self.fc_3(trunk), condition))

        rot = F.elu(self.raln1(self.r_fc_1(trunk), condition))
        rot = F.elu(self.raln2(self.r_fc_2(rot), condition))
        rot = torch.tanh(self.r_fc_3(rot)) * np.pi

        trans = F.elu(self.taln1(self.t_fc_1(trunk), condition))
        trans = F.elu(self.taln2(self.t_fc_2(trans), condition))
        trans = self.t_fc_3(trans)

        return rot, trans
    
class PoseNetFeat(nn.Module):
    """Per-point fusion of geometry and image-embedding features.

    Points (3 channels) and embeddings (32 channels) go through two parallel
    1x1-convolution stages; the second-stage features are lifted to 1024
    channels, average-pooled over all points and broadcast back.

    Args:
        large: select the wide variant (2048 output channels) instead of the
            default one (1408 output channels).
    """

    def __init__(self, large = False):
        super().__init__()
        # (stage-1 width, stage-2 width) of each of the two branches.
        stage1, stage2 = (256, 256) if large else (64, 128)

        self.conv1 = torch.nn.Conv1d(3, stage1, 1)
        self.conv2 = torch.nn.Conv1d(stage1, stage2, 1)

        self.e_conv1 = torch.nn.Conv1d(32, stage1, 1)
        self.e_conv2 = torch.nn.Conv1d(stage1, stage2, 1)

        # conv5 consumes the concatenated stage-2 point and embedding features.
        self.conv5 = torch.nn.Conv1d(2 * stage2, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

    def forward(self, x, emb):
        """Fuse ``x`` (B, 3, N) with ``emb`` (B, 32, N).

        Returns the channel-wise concatenation of the stage-1 features, the
        stage-2 features and the broadcast 1024-channel global feature.
        """
        n_points = x.shape[2]

        geo_1 = F.relu(self.conv1(x))
        emb_1 = F.relu(self.e_conv1(emb))
        stage1_feat = torch.cat((geo_1, emb_1), dim=1)

        geo_2 = F.relu(self.conv2(geo_1))
        emb_2 = F.relu(self.e_conv2(emb_1))
        stage2_feat = torch.cat((geo_2, emb_2), dim=1)

        lifted = F.relu(self.conv5(stage2_feat))
        lifted = F.relu(self.conv6(lifted))

        # One global descriptor per sample, repeated for every point.
        global_feat = F.avg_pool1d(lifted, n_points)
        global_feat = global_feat.view(-1, 1024, 1).repeat(1, 1, n_points)

        return torch.cat([stage1_feat, stage2_feat, global_feat], 1)  # 128 + 256 + 1024 when not large

class PoseNet(nn.Module):
    """Dense per-point pose predictor (quaternion, translation, confidence).

    An image CNN supplies per-pixel embeddings, ``PoseNetFeat`` fuses them
    with the point cloud, and three parallel stacks of 1x1 convolutions emit,
    for every point and every object slot, a 4-vector, a 3-vector and a
    scalar.

    Args:
        num_obj: number of object slots in the output layers.
        large: use the wide feature extractor and wide heads.
    """

    def __init__(self, num_obj, large=False):
        super().__init__()
        self.cnn = ModifiedResnet().to(device)
        self.feat = PoseNetFeat(large=large)

        # Channel widths: fused input, then the three hidden stages.
        c_in, c1, c2, c3 = (2048, 1024, 512, 512) if large else (1408, 640, 256, 128)

        self.conv1_r = torch.nn.Conv1d(c_in, c1, 1)
        self.conv1_t = torch.nn.Conv1d(c_in, c1, 1)
        self.conv1_c = torch.nn.Conv1d(c_in, c1, 1)

        self.conv2_r = torch.nn.Conv1d(c1, c2, 1)
        self.conv2_t = torch.nn.Conv1d(c1, c2, 1)
        self.conv2_c = torch.nn.Conv1d(c1, c2, 1)

        self.conv3_r = torch.nn.Conv1d(c2, c3, 1)
        self.conv3_t = torch.nn.Conv1d(c2, c3, 1)
        self.conv3_c = torch.nn.Conv1d(c2, c3, 1)

        self.conv4_r = torch.nn.Conv1d(c3, num_obj*4, 1)  # quaternion
        self.conv4_t = torch.nn.Conv1d(c3, num_obj*3, 1)  # translation
        self.conv4_c = torch.nn.Conv1d(c3, num_obj*1, 1)  # confidence

        self.num_obj = num_obj

    def forward(self, img, x, mask, obj):
        """Predict per-point poses for object slot ``obj``.

        Only the first image of the CNN output is used and the outputs are
        reshaped with a leading dimension of 1, so a single sample is
        processed per call. Returns ``(out_rx, out_tx, out_cx, emb)`` where
        ``emb`` is the detached masked-pixel embedding.
        """
        n_points = x.shape[1]

        # Channels-last feature map of image 0, gathered at the masked pixels.
        pixel_map = torch.permute(self.cnn(img)[0], [1, 2, 0])
        emb = torch.unsqueeze(pixel_map[mask], 0).transpose(2, 1).contiguous()

        fused = self.feat(x.transpose(2, 1).contiguous(), emb)

        rx = F.relu(self.conv1_r(fused))
        tx = F.relu(self.conv1_t(fused))
        cx = F.relu(self.conv1_c(fused))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))
        cx = F.relu(self.conv2_c(cx))

        rx = F.relu(self.conv3_r(rx))
        tx = F.relu(self.conv3_t(tx))
        cx = F.relu(self.conv3_c(cx))

        rx = self.conv4_r(rx).view(1, self.num_obj, 4, n_points)
        tx = self.conv4_t(tx).view(1, self.num_obj, 3, n_points)
        cx = self.conv4_c(cx).view(1, self.num_obj, 1, n_points)

        # Select the requested object slot and move the point axis forward.
        out_rx = rx[:, obj, :, :].transpose(2, 1).contiguous()
        out_cx = cx[:, obj, :, :].transpose(2, 1).contiguous()
        out_tx = tx[:, obj, :, :].transpose(2, 1).contiguous()

        return out_rx, out_tx, out_cx, emb.detach()

class AttentionPooling(nn.Module):
    """
    Implementation of SelfAttentionPooling 
    Original Paper: Self-Attention Encoding and Pooling for Speaker Recognition
    https://arxiv.org/pdf/2008.01077v1.pdf

    A single linear layer scores every element of the sequence; the scores are
    softmax-normalised over the sequence axis and used as weights for a sum
    over that axis.
    """
    def __init__(self, input_dim):
        super(AttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)
        
    def forward(self, batch_rep):
        """
        input:
            batch_rep : size (N, T, H), N: batch size, T: sequence length, H: Hidden dimension
        
        attention_weight:
            att_w : size (N, T, 1)
        
        return:
            utter_rep: size (N, H)
        """
        scores = self.W(batch_rep).squeeze(-1)
        # Normalise over the sequence axis explicitly. Calling softmax without
        # `dim` is deprecated, emits a warning on every forward pass and picks
        # the axis implicitly from the tensor rank.
        att_w = F.softmax(scores, dim=-1).unsqueeze(-1)
        utter_rep = torch.sum(batch_rep * att_w, dim=1)

        return utter_rep
    

    
    
class PoseNet_Attn(nn.Module):
    """Pose predictor that pools per-point features with attention.

    Shares the CNN and ``PoseNetFeat`` front end with ``PoseNet`` but, instead
    of emitting one pose per point, collapses the point axis with
    ``AttentionPooling`` and predicts a single quaternion and translation per
    object slot. There is no confidence head.

    Args:
        num_obj: number of object slots in the output layers.
        large: use the wide feature extractor and wide heads.
    """

    def __init__(self, num_obj, large=False):
        super().__init__()
        self.cnn = ModifiedResnet().to(device)
        self.feat = PoseNetFeat(large=large)

        # Channel widths: fused input, then the three hidden stages.
        c_in, c1, c2, c3 = (2048, 1024, 512, 512) if large else (1408, 640, 256, 128)

        self.conv1_r = torch.nn.Conv1d(c_in, c1, 1)
        self.conv1_t = torch.nn.Conv1d(c_in, c1, 1)

        self.conv2_r = torch.nn.Conv1d(c1, c2, 1)
        self.conv2_t = torch.nn.Conv1d(c1, c2, 1)

        self.conv3_r = torch.nn.Conv1d(c2, c3, 1)
        self.conv3_t = torch.nn.Conv1d(c2, c3, 1)

        self.r_attn_pool = AttentionPooling(c3)
        self.t_attn_pool = AttentionPooling(c3)

        self.conv4_r = torch.nn.Conv1d(c3, num_obj*4, 1)  # quaternion
        self.conv4_t = torch.nn.Conv1d(c3, num_obj*3, 1)  # translation

        self.num_obj = num_obj

    def forward(self, img, x, mask, obj):
        """Predict one pose for object slot ``obj``.

        Returns ``(out_rx, out_tx, None, emb)``: the quaternion and
        translation keep a singleton object axis (``obj:obj+1``), the third
        entry is ``None`` because this model has no confidence output, and
        ``emb`` is the detached masked-pixel embedding of image 0.
        """
        batch = x.shape[0]

        # Channels-last feature map of image 0, gathered at the masked pixels.
        pixel_map = torch.permute(self.cnn(img)[0], [1, 2, 0])
        emb = torch.unsqueeze(pixel_map[mask], 0).transpose(2, 1).contiguous()

        fused = self.feat(x.transpose(2, 1).contiguous(), emb)

        rx = F.relu(self.conv1_r(fused))
        tx = F.relu(self.conv1_t(fused))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))

        rx = F.relu(self.conv3_r(rx))
        tx = F.relu(self.conv3_t(tx))

        # Attention pooling expects (B, N, C); restore a length-1 point axis
        # afterwards so the 1x1 convolutions can be applied.
        rx = self.r_attn_pool(rx.transpose(2, 1)).unsqueeze(-1)
        tx = self.t_attn_pool(tx.transpose(2, 1)).unsqueeze(-1)

        rx = self.conv4_r(rx).view(batch, self.num_obj, 4)
        tx = self.conv4_t(tx).view(batch, self.num_obj, 3)

        out_rx = rx[:, obj:obj+1, :]
        out_tx = tx[:, obj:obj+1, :]

        return out_rx, out_tx, None, emb.detach()
    
class PoseNet_Encoder(nn.Module):
    """Geometry-only point-cloud encoder producing one global feature per sample.

    Four 1x1 convolutions with ReLU are applied per point and the result is
    average-pooled over all points.

    Args:
        large: if True use the wide variant (1024 output channels), otherwise
            the default one (128 output channels).
    """

    def __init__(self, large = False):
        super(PoseNet_Encoder, self).__init__()
        if large:
            self.conv1 = torch.nn.Conv1d(3, 256, 1)
            self.conv2 = torch.nn.Conv1d(256, 256, 1)

            # conv5 is fed directly by conv2 (there is no embedding branch to
            # concatenate with), so it must accept 256 channels. The previous
            # 512-channel input made every large=True forward pass fail.
            self.conv5 = torch.nn.Conv1d(256, 512, 1)
            self.conv6 = torch.nn.Conv1d(512, 1024, 1)
        else:
            self.conv1 = torch.nn.Conv1d(3, 64, 1)
            self.conv2 = torch.nn.Conv1d(64, 128, 1)

            self.conv5 = torch.nn.Conv1d(128, 128, 1)
            self.conv6 = torch.nn.Conv1d(128, 128, 1)

    def forward(self, x):
        """Encode ``x`` of shape (B, 3, N) into a (B, C, 1) global feature.

        ``C`` is 128 for the default model and 1024 for ``large=True``.
        """
        num_points = x.shape[2]

        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))

        ap_x = torch.nn.AvgPool1d(num_points)(x)

        # Use the actual channel count instead of a hard-coded 128 so the
        # large variant is not silently reshaped into the wrong batch size.
        ap_x = ap_x.view(-1, ap_x.shape[1], 1)
        return ap_x

class PoseRefineNetFeat(nn.Module):
    """Global feature extractor used by the pose refiner.

    Points and embeddings pass through two parallel 1x1-convolution stages;
    all four intermediate feature maps are concatenated (384 channels),
    lifted to 1024 channels and average-pooled over the points.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)

        self.e_conv1 = torch.nn.Conv1d(32, 64, 1)
        self.e_conv2 = torch.nn.Conv1d(64, 128, 1)

        # Not used by forward(); kept so existing parameter names stay intact.
        self.ref_conv1 = torch.nn.Conv1d(128, 128, 1)
        self.ref_conv2 = torch.nn.Conv1d(128, 128, 1)

        self.conv5 = torch.nn.Conv1d(256+128, 512, 1)
        self.conv6 = torch.nn.Conv1d(512, 1024, 1)

    def forward(self, x, emb):
        """Map ``x`` (B, 3, N) and ``emb`` (B, 32, N) to a (B, 1024) feature."""
        n_points = x.shape[2]

        geo_1 = F.relu(self.conv1(x))
        emb_1 = F.relu(self.e_conv1(emb))
        geo_2 = F.relu(self.conv2(geo_1))
        emb_2 = F.relu(self.e_conv2(emb_1))

        # Same channel order as concatenating the stage-1 and stage-2 pairs.
        stacked = torch.cat([geo_1, emb_1, geo_2, emb_2], dim=1)

        lifted = F.relu(self.conv5(stacked))
        lifted = F.relu(self.conv6(lifted))

        pooled = F.avg_pool1d(lifted, n_points)
        return pooled.view(-1, 1024)

class PoseRefineNet(nn.Module):
    """Residual pose predictor applied on top of an initial pose estimate.

    A ``PoseRefineNetFeat`` global feature is mapped by two small MLPs to a
    4-vector (quaternion) and a 3-vector (translation) per object slot.

    Args:
        num_obj: number of object slots in the output layers.
    """

    def __init__(self, num_obj):
        super().__init__()
        self.feat = PoseRefineNetFeat()

        # The attributes are called ``conv*`` but are fully connected layers.
        self.conv1_r = torch.nn.Linear(1024, 512)
        self.conv1_t = torch.nn.Linear(1024, 512)

        self.conv2_r = torch.nn.Linear(512, 128)
        self.conv2_t = torch.nn.Linear(512, 128)

        self.conv3_r = torch.nn.Linear(128, num_obj*4)  # quaternion
        self.conv3_t = torch.nn.Linear(128, num_obj*3)  # translation

        self.num_obj = num_obj

    def forward(self, x, emb, obj):
        """Return ``(out_rx, out_tx)`` for object slot ``obj``.

        ``x`` is given points-first (B, N, 3) and transposed internally;
        ``emb`` is passed to the feature extractor unchanged.
        """
        batch = x.shape[0]

        global_feat = self.feat(x.transpose(2, 1).contiguous(), emb)

        rx = F.relu(self.conv1_r(global_feat))
        tx = F.relu(self.conv1_t(global_feat))

        rx = F.relu(self.conv2_r(rx))
        tx = F.relu(self.conv2_t(tx))

        rx = self.conv3_r(rx).view(batch, self.num_obj, 4)
        tx = self.conv3_t(tx).view(batch, self.num_obj, 3)

        return rx[:, obj, :], tx[:, obj, :]

In [None]:
def get_sample_posnet(i, aug=True, max_points=500):
    """Assemble one sample for the pose networks from the preloaded arrays.

    Args:
        i: index into ``point_cloud_pairs`` (and ``idx_list``).
        aug: if True, add uniform noise in [-0.3, 0.3] to the RGB image and
            clip it back to [0, 1].
        max_points: upper bound on the number of masked pixels kept and on the
            number of reference points sampled (default 500).

    Returns:
        Tuple ``(rgb, pc, mask, obj_class, R, t, is_sym, diameter, target,
        target_pc)`` where
        * ``rgb`` is the image scaled to [0, 1] (noisy when ``aug``),
        * ``pc`` are the ``depth_imgs`` values at the kept mask pixels, times 1000,
        * ``mask`` is the boolean object mask after sub-sampling,
        * ``obj_class`` is ``obj_id - 1``,
        * ``R`` and ``t`` are the stored pose values, returned unchanged,
        * ``is_sym`` tells whether ``obj_id`` is in the symmetric-object list,
        * ``diameter`` is ``diameters[obj_id - 1]``,
        * ``target`` is the sampled reference cloud transformed by ``R``/``t``, times 1000,
        * ``target_pc`` is the sampled, untransformed reference cloud, times 1000.
    """
    sym = [1,2,10,13,14]
    _, _, obj_id, R, t = point_cloud_pairs[i]
    idx = idx_list[i]
    obj_class = obj_id - 1

    rgb = rgbs[idx]/255.0
    mask = (merged_masks[idx] == obj_to_mask_value[obj_class])
    depth = depth_imgs[idx]

    # In-plane rotation augmentation is currently disabled: the angle is 0 for
    # both aug=True and aug=False. The rotate() round trip is kept so array
    # dtypes stay exactly as before.
    angle = 0.0
    rotate = torchvision.transforms.functional.rotate
    rgb = np.transpose(rotate(torch.Tensor(np.transpose(rgb,[2,0,1])),angle).detach().numpy(),[1,2,0])
    mask = rotate(torch.Tensor(np.transpose(np.expand_dims(mask,-1),[2,0,1]).astype(np.float32)),angle).detach().numpy()[0].astype(bool)
    depth = np.transpose(rotate(torch.Tensor(np.transpose(depth,[2,0,1])),angle).detach().numpy(),[1,2,0])

    # Randomly switch off mask pixels until at most `max_points` remain.
    n_masked = mask.sum()
    if n_masked > max_points:
        drop = np.random.choice(n_masked, size=n_masked-max_points, replace=False)
        rows, cols = np.where(mask)
        mask[rows[drop], cols[drop]] = False

    pc = depth[mask]
    target_pc = reference_point_clouds[obj_class][0]

    # Counter-rotation about z matching the image rotation (identity for angle 0).
    RR = roma.rotvec_to_rotmat(torch.Tensor([0,0,-angle*np.pi/180])).detach().numpy()
    pc = (RR@pc.T).T

    if target_pc.shape[0] > max_points:
        target_pc = target_pc[np.random.choice(target_pc.shape[0], size=max_points, replace=False)]

    target = (RR@((np.reshape(R,[3,3])@target_pc.T).T + t).T).T

    if aug:
        rgb = rgb + np.random.uniform(low =-0.3, high=0.3,size=rgb.shape)
        rgb = np.minimum(rgb,1.0)
        rgb = np.maximum(rgb,0.0)

    return rgb, pc*1000.0 ,mask,obj_class, R, t, obj_id in sym, diameters[obj_id-1], target*1000.0, target_pc*1000.0

In [None]:
# Pose estimator with 18 object slots, moved to the configured device.
pose_model = PoseNet_Attn(18).to(device)

# Optimisation hyper-parameters for the pose network.
learning_rate = 1e-4
epochs = 100

# Adam plus a step schedule that multiplies the learning rate by 0.6 every
# 5 calls to decay_stepper.step().
optimizer = torch.optim.Adam(pose_model.parameters(), lr=learning_rate)
decay_stepper = torch.optim.lr_scheduler.StepLR(optimizer, 5, gamma=0.6)

# NOTE(review): `w` is not referenced by any visible cell; presumably a
# leftover loss weight - confirm before removing.
w = 0.3

In [None]:
# Split sample indices by the frame each sample belongs to: a sample goes to
# the training list if its frame is in `train`, otherwise to the test list if
# its frame is in `test`; samples in neither are dropped.
point_cloud_pairs_train, point_cloud_pairs_test = [], []
for i in range(len(point_cloud_pairs)):
    frame = idx_list[i]
    if frame in train:
        bucket = point_cloud_pairs_train
    elif frame in test:
        bucket = point_cloud_pairs_test
    else:
        continue
    bucket.append(i)

In [None]:
# Train the pose network. From the second epoch on, each epoch first evaluates
# on the test samples and checkpoints the model when the number of ADD(-S)
# hits matches or beats the best count so far.
best_ADD = 0.0
for epoch in range(epochs):

    # ---------------- validation (skipped before the first training epoch) ----------------
    l_ov = 0.0
    ADD = 0
    cc = 0
    pose_model.eval()
    if epoch>0:
        # No gradients are needed for evaluation; this avoids building the
        # autograd graph for every validation sample.
        with torch.no_grad():
            for i in tqdm(point_cloud_pairs_test):
                cc+= 1

                rgb,pc,mask,obj_class, _, _, is_sym, d, target, target_pc = get_sample_posnet(i,aug=False)

                if pc.shape[0]>0:
                    R,t,c,emb = pose_model(torch.Tensor(np.expand_dims(np.transpose(rgb,[2,0,1]),0)).to(device),torch.Tensor(np.expand_dims(pc,0)).to(device),mask,obj_class)
                    # Normalise the predicted 4-vector to a unit quaternion.
                    R = R / (torch.norm(R, dim=2).view(1, -1, 1))
                    Rs = roma.unitquat_to_rotmat(R[:,0,:])
                    t = t[:,0,:]
                    transformed = torch.bmm(Rs,torch.Tensor(np.expand_dims(target_pc,0)).transpose(1,2).to(device)).transpose(1,2) + t
                    target = torch.Tensor(target).to(device)

                    if is_sym:
                        # Symmetric objects: distance to the closest target point (ADD-S).
                        L_p = torch.mean(torch.min(torch.cdist(transformed,torch.unsqueeze(target,0)),dim=-1)[0])
                    else:
                        # Otherwise: distance between corresponding points (ADD).
                        L_p = torch.mean(torch.norm(transformed[0] - target,dim=-1),dim=-1)

                    # A pose counts as correct when the mean distance is
                    # within 10% of the object diameter.
                    if L_p <= 0.1 * d:
                        ADD += 1

                    l_ov += L_p.item()
        if ADD >= best_ADD:
            best_ADD = ADD
            torch.save(pose_model,'best_custom_new.pt')
        # max(cc, 1) guards against an empty validation list.
        print('Validation Lp: %.7f, ADD(-S): %.7f'%(l_ov/max(cc,1),ADD/max(cc,1)))

    # ---------------- training ----------------
    random.shuffle(point_cloud_pairs_train)
    prog = tqdm(point_cloud_pairs_train)
    l_ov = 0.0
    ADD = 0
    cc = 0
    pose_model = pose_model.train()
    for i in prog:
        cc+= 1
        # Warm-up: ramp the learning rate exponentially from 1e-5 to 1e-4
        # over the first 1000 steps of the first epoch.
        if epoch == 0 and cc<=1000:
            optimizer.param_groups[0]['lr'] = 1e-5 * 10**(cc/1000.)
        optimizer.zero_grad()

        rgb,pc,mask,obj_class, _, _, is_sym, d, target, target_pc = get_sample_posnet(i)
        if pc.shape[0]>0:
            R,t,c,emb = pose_model(torch.Tensor(np.expand_dims(np.transpose(rgb,[2,0,1]),0)).to(device),torch.Tensor(np.expand_dims(pc,0)).to(device),mask,obj_class)
            R = R / (torch.norm(R, dim=2).view(1, -1, 1))
            Rs = roma.unitquat_to_rotmat(R)
            transformed = torch.bmm(Rs[0],torch.tile(torch.Tensor(np.expand_dims(target_pc,0)),[Rs.shape[1],1,1]).transpose(1,2).to(device)).transpose(1,2) + torch.unsqueeze(t[0],1)
            target = torch.Tensor(target).to(device)

            if is_sym:
                L_p = torch.cdist(transformed.reshape(-1,3),target).min(-1)[0].reshape(transformed.shape[0],transformed.shape[1]).mean(-1)
            else:
                target = torch.unsqueeze(target,0)
                L_p = torch.mean(torch.norm(transformed - target,dim=-1),dim=-1)

            if L_p[0].item() <= 0.1 * d:
                ADD += 1

            loss = torch.mean(L_p)
            l_ov += torch.min(L_p).item()
            loss.backward()
            optimizer.step()

            prog.set_postfix_str('ADD(-S): %.7f, overall_loss: %.7f, loss: %.7f, Lp: %.7f'%(ADD/cc,l_ov/cc,loss.item(),torch.min(L_p).item()))

    decay_stepper.step()

In [None]:
# Pose refiner with 18 object slots, moved to the configured device.
refine_model = PoseRefineNet(18).to(device)

# Optimisation hyper-parameters for the refiner (these overwrite the
# `learning_rate` / `epochs` globals set for the pose network).
learning_rate = 1e-3
epochs = 100

# Adam over the refiner's parameters only, with a step schedule that
# multiplies the learning rate by 0.6 every 10 calls to decay_stepper_r.step().
optimizer_r = torch.optim.Adam(refine_model.parameters(), lr=learning_rate)
decay_stepper_r = torch.optim.lr_scheduler.StepLR(optimizer_r, 10, gamma=0.6)

# Put the pose network in evaluation mode; optimizer_r does not hold its parameters.
pose_model = pose_model.eval()

In [None]:
def _quat_to_rotmat(quat):
    """Normalise (B, 4) quaternions to unit length and convert them to rotation matrices."""
    quat = quat / (torch.norm(quat, dim=-1).view(-1, 1))
    return roma.unitquat_to_rotmat(quat)


def _refined_model_points(rgb, pc, mask, obj_class, target_pc, n_iters=3):
    """Run the frozen pose network followed by `n_iters` refinement passes.

    Returns the reference points `target_pc` transformed by the final pose
    estimate, shape (B, N, 3).
    """
    torch_rgb = torch.Tensor(np.expand_dims(np.transpose(rgb,[2,0,1]),0)).to(device)
    torch_pc = torch.Tensor(np.expand_dims(pc,0)).to(device)

    # The pose network is not optimised here, so its forward pass does not
    # need an autograd graph; the refiner still receives gradients below.
    with torch.no_grad():
        R, t, _, emb = pose_model(torch_rgb, torch_pc, mask, obj_class)
        R = R / (torch.norm(R, dim=2).view(1, -1, 1))
        R_final = roma.unitquat_to_rotmat(R)[:,0]
        T_final = t[:,0]

    for _ in range(n_iters):
        # Express the observed points in the frame of the current estimate.
        R_inv = roma.rotmat_inverse(R_final)
        new_pc = torch.bmm(R_inv,(torch_pc.repeat(R_inv.shape[0],1,1) - T_final.unsqueeze(1)).transpose(1,2)).transpose(1,2)

        R_r, T_r = refine_model(new_pc, emb.repeat(new_pc.shape[0],1,1), obj_class)
        R_r = _quat_to_rotmat(R_r)

        # Accumulate the predicted correction (same composition as before).
        R_final = torch.bmm(R_r, R_final)
        T_final = T_final + T_r

    model_pts = torch.Tensor(target_pc.T).to(device).unsqueeze(0).repeat([R_final.shape[0],1,1])
    return torch.bmm(R_final, model_pts).transpose(1,2) + T_final.unsqueeze(1)


def _pose_distance(transformed, torch_target, is_sym):
    """Mean point distance per sample: closest-point for symmetric objects, point-to-point otherwise."""
    if is_sym:
        return torch.cdist(transformed.reshape(-1,3),torch_target).min(-1)[0].reshape(transformed.shape[0],transformed.shape[1]).mean(-1)
    return torch.norm(transformed - torch_target.unsqueeze(0),dim=-1).mean(-1)


# Best validation ADD(-S) *rate* so far. A checkpoint is only written when the
# validation rate reaches this value. (The previous code compared this rate
# against the raw hit count, so the 0.75 threshold had no effect.)
best_ADD = 0.75
for epoch in range(epochs):

    # ---------------- validation (skipped before the first training epoch) ----------------
    l_ov = 0.0
    ADD = 0
    cc = 0
    refine_model = refine_model.eval()
    if epoch>0:
        with torch.no_grad():
            for i in tqdm(point_cloud_pairs_test):
                cc+= 1

                rgb,pc,mask,obj_class, _, _, is_sym, d, target, target_pc = get_sample_posnet(i,aug=False)
                if pc.shape[0]>0:
                    torch_target = torch.Tensor(target).to(device)
                    transformed = _refined_model_points(rgb, pc, mask, obj_class, target_pc)
                    L_p = _pose_distance(transformed, torch_target, is_sym)

                    # Correct when the mean distance is within 10% of the diameter.
                    if L_p.item() <= 0.1 * d:
                        ADD += 1

                    l_ov += L_p.item()

        # max(cc, 1) guards against an empty validation list.
        add_rate = ADD/max(cc,1)
        if add_rate >= best_ADD:
            best_ADD = add_rate
            torch.save(refine_model,'best_linemode_refine_final.pt')
        print('Validation Lp: %.7f, ADD(-S): %.7f'%(l_ov/max(cc,1),add_rate))

    # ---------------- training ----------------
    random.shuffle(point_cloud_pairs_train)
    prog = tqdm(point_cloud_pairs_train)
    l_ov = 0.0
    ADD = 0
    cc = 0
    refine_model = refine_model.train()

    for i in prog:
        cc+= 1

        optimizer_r.zero_grad()

        rgb,pc,mask,obj_class, _, _, is_sym, d, target, target_pc = get_sample_posnet(i)
        if pc.shape[0]>0:
            torch_target = torch.Tensor(target).to(device)
            transformed = _refined_model_points(rgb, pc, mask, obj_class, target_pc)
            L_p = _pose_distance(transformed, torch_target, is_sym)

            if L_p.item() <= 0.1 * d:
                ADD += 1

            l_ov += L_p.item()

            loss = L_p.mean()

            loss.backward()
            optimizer_r.step()

            prog.set_postfix_str('ADD(-S): %.7f, overall_loss: %.7f, loss: %.7f, Lp: %.7f'%(ADD/cc,l_ov/cc,loss.item(),torch.min(L_p).item()))

    decay_stepper_r.step()