In [23]:
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

In [10]:
# Inspect per-GPU utilization and memory before choosing `available_gpus`.
!gpustat

[1m[37m8d809b5da21a       [m  Tue Aug 16 08:00:55 2022  [1m[30m460.73.01[m
[36m[0][m [34mGeForce RTX 3090[m |[1m[31m 67'C[m, [1m[32m 94 %[m | [36m[1m[33m11538[m / [33m24268[m MB |
[36m[1][m [34mGeForce RTX 3090[m |[31m 29'C[m, [32m  0 %[m | [36m[1m[33m    8[m / [33m24268[m MB |
[36m[2][m [34mGeForce RTX 3090[m |[31m 44'C[m, [32m  0 %[m | [36m[1m[33m10325[m / [33m24268[m MB |
[36m[3][m [34mGeForce RTX 3090[m |[1m[31m 53'C[m, [1m[32m 54 %[m | [36m[1m[33m15031[m / [33m24268[m MB |
[36m[4][m [34mGeForce RTX 3090[m |[1m[31m 54'C[m, [1m[32m 63 %[m | [36m[1m[33m 1761[m / [33m24268[m MB |
[36m[5][m [34mGeForce RTX 3090[m |[1m[31m 56'C[m, [1m[32m 88 %[m | [36m[1m[33m23739[m / [33m24268[m MB |
[36m[6][m [34mGeForce RTX 3090[m |[1m[31m 55'C[m, [1m[32m 94 %[m | [36m[1m[33m23739[m / [33m24268[m MB |
[36m[7][m [34mGeForce RTX 3090[m |[31m 36'C[m, [32m  0 %[m | [36m[1m[33m 1305

In [11]:
# GPUs this notebook may use; the first one hosts the models
# (None means `.cuda()` falls back to the current CUDA device).
available_gpus = [1, 7]
dev = available_gpus[0] if available_gpus else None

In [24]:
def knn(x, k):
    """Indices of the k nearest neighbours of every point (self included).

    Args:
        x: point features of shape (batch_size, num_dims, num_points).
        k: number of neighbours to return.

    Returns:
        LongTensor of shape (batch_size, num_points, k), ordered from nearest
        to farthest by squared Euclidean distance.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = (x ** 2).sum(dim=1, keepdim=True)
    # -||xi - xj||^2 = 2<xi, xj> - ||xi||^2 - ||xj||^2; larger means closer,
    # so topk picks the nearest points.
    neg_sq_dist = 2 * gram - sq_norm - sq_norm.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest


def get_graph_feature(x, k, knn_only=False, disp_only=False):
    """Build DGCNN edge features from the k nearest neighbours of each point.

    Args:
        x: point features, viewed as (batch_size, num_dims, num_points).
        k: neighbours per point (the point itself is included because its
            distance to itself is 0).
        knn_only: if True, return the gathered neighbour features of shape
            (batch_size, num_points, k, num_dims) without centring.
        disp_only: if True, return only the displacements neighbour - centre
            with shape (batch_size, num_dims, num_points, k).

    Returns:
        Tensor of shape (batch_size, 2*num_dims, num_points, k) stacking
        [neighbour - centre, centre] along the channel axis.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx = knn(x, k=k)   # (batch_size, num_points, k)
    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points, num_dims) matrix below. The offsets are created
    # directly on idx's device: the previous x.get_device() + .cuda(device)
    # broke on CPU tensors (get_device() returns -1 there).
    idx_base = torch.arange(batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)
    num_dims = x.size(1)
    # (batch_size, num_dims, num_points) -> (batch_size*num_points, num_dims)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    if knn_only:
        return feature
    # Centre point repeated once per neighbour: (batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    if disp_only:
        return (feature - x).permute(0, 3, 1, 2)
    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)


class Transform_Net(nn.Module):
    """Predicts a 3x3 matrix per point cloud and applies it to the input.

    The encoder is DGCNN-style: edge convs, max over neighbours, a point-wise
    conv, max over points, then an MLP head. The final linear layer starts with
    zero weights and an identity bias, so the predicted transform is the
    identity at initialisation.

    Args:
        args: config object; only ``args.k`` (neighbourhood size) is read.
    """

    def __init__(self, args):
        super(Transform_Net, self).__init__()

        def conv_bn_act(conv, norm):
            # conv -> batch norm -> LeakyReLU(0.2). The norm is also stored as a
            # direct attribute (bn1..bn3), so both key paths appear in state_dict.
            return nn.Sequential(conv, norm, nn.LeakyReLU(negative_slope=0.2))

        self.k = args.k

        # Creation order is kept as-is so parameter initialisation consumes the
        # RNG identically and state_dict keys are unchanged.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = conv_bn_act(nn.Conv2d(6, 64, kernel_size=1, bias=False), self.bn1)
        self.conv2 = conv_bn_act(nn.Conv2d(64, 128, kernel_size=1, bias=False), self.bn2)
        self.conv3 = conv_bn_act(nn.Conv1d(128, 1024, kernel_size=1, bias=False), self.bn3)

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn4 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn5 = nn.BatchNorm1d(256)

        # Zero weights + identity bias: the predicted transform starts as I.
        self.transform = nn.Linear(256, 3*3)
        nn.init.constant_(self.transform.weight, 0)
        nn.init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        """Apply the learned transform to ``x`` of shape (B, 3, N); returns (B, 3, N)."""
        batch_size = x.size(0)

        # Edge features (B, 6, N, k) -> (B, 64, N, k) -> (B, 128, N, k),
        # then max over the k neighbours -> (B, 128, N).
        feats = self.conv2(self.conv1(get_graph_feature(x, k=self.k)))
        feats = feats.max(dim=-1, keepdim=False)[0]

        # (B, 128, N) -> (B, 1024, N), then max over points -> (B, 1024).
        feats = self.conv3(feats).max(dim=-1, keepdim=False)[0]

        # MLP head: (B, 1024) -> (B, 512) -> (B, 256).
        for linear, norm in ((self.linear1, self.bn4), (self.linear2, self.bn5)):
            feats = F.leaky_relu(norm(linear(feats)), negative_slope=0.2)

        # (B, 256) -> (B, 9) -> (B, 3, 3)
        matrix = self.transform(feats).view(batch_size, 3, 3)

        # (B, N, 3) @ (B, 3, 3) -> (B, N, 3), then back to (B, 3, N).
        return torch.bmm(x.transpose(2, 1), matrix).transpose(2, 1)


class DGCNN_partseg(nn.Module):
    """DGCNN network producing per-point part-segmentation scores.

    Args:
        args: config read for ``k`` (neighbourhood size), ``emb_dims`` (width
            of the global feature) and ``dropout`` (dropout probability).
        seg_num_all: number of output channels per point.

    ``forward(x, l)`` takes a point cloud ``x`` of shape (batch_size, 3,
    num_points) and a per-sample vector ``l`` that is reshaped to
    (batch_size, 16, 1) for ``conv7`` (presumably a one-hot object category —
    TODO confirm against the data loader). It returns a tensor of shape
    (batch_size, seg_num_all, num_points).
    """

    def __init__(self, args, seg_num_all):
        super(DGCNN_partseg, self).__init__()
        self.args = args
        self.seg_num_all = seg_num_all
        self.tnet = Transform_Net(args)
        self.k = args.k

        # The batch norms are registered both as direct attributes and inside
        # the conv blocks; both key sets appear in state_dict, so names and
        # creation order are kept unchanged for checkpoint compatibility.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(args.emb_dims)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, args.emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv8 sees [global feature (emb_dims) + label feature (64) +
        # x1, x2, x3 (3 * 64)]. Derive the width from emb_dims instead of
        # hard-coding 1280, which is only correct for emb_dims == 1024.
        self.conv8 = nn.Sequential(nn.Conv1d(args.emb_dims + 64 + 64*3, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=args.dropout)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=args.dropout)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                    self.bn10,
                                    nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all,
                                kernel_size=1, bias=False)

    def forward(self, x, l):
        batch_size = x.size(0)
        num_points = x.size(2)

        x = self.tnet(x)

        # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = get_graph_feature(x, k=self.k)
        # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv1(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x1 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = get_graph_feature(x1, k=self.k)
        # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv3(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x2 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = get_graph_feature(x2, k=self.k)
        # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv5(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x3 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64*3, num_points)
        x = torch.cat((x1, x2, x3), dim=1)

        # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = self.conv6(x)
        # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)
        x = x.max(dim=-1, keepdim=True)[0]

        # (batch_size, 16, 1)
        l = l.view(batch_size, -1, 1)
        # (batch_size, 16, 1) -> (batch_size, 64, 1)
        l = self.conv7(l)

        x = torch.cat((x, l), dim=1)            # (batch_size, emb_dims+64, 1)
        # (batch_size, emb_dims+64, num_points)
        x = x.repeat(1, 1, num_points)

        # (batch_size, emb_dims+64+64*3, num_points)
        x = torch.cat((x, x1, x2, x3), dim=1)

        # (batch_size, emb_dims+64+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.conv8(x)
        x = self.dp1(x)
        # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.conv9(x)
        x = self.dp2(x)
        # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv10(x)
        # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)
        x = self.conv11(x)

        return x


class DGCNN(nn.Module):
    """DGCNN encoder returning one ``emb_dim``-wide embedding per point.

    Args:
        k: neighbourhood size used when building edge features.
        emb_dim: output channels of the final point-wise conv (``conv6``).

    ``forward(x)`` takes ``x`` of shape (batch_size, 3, num_points) and returns
    a tensor of shape (batch_size, num_points, emb_dim).
    """

    def __init__(self, k, emb_dim):
        super(DGCNN, self).__init__()
        self.k = k
        self.emb_dim = emb_dim

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, emb_dim, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(emb_dim),
                                   nn.LeakyReLU(negative_slope=0.2))

    def forward(self, x):
        # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = get_graph_feature(x, k=self.k)
        # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv1(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x1 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = get_graph_feature(x1, k=self.k)
        # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv3(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x2 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = get_graph_feature(x2, k=self.k)
        # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv5(x)
        # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        x3 = x.max(dim=-1, keepdim=False)[0]

        # (batch_size, 64*3, num_points)
        x = torch.cat((x1, x2, x3), dim=1)
        # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = self.conv6(x)
        # (batch_size, emb_dims, num_points) -> (batch_size, num_points, emb_dims).
        # This must be a transpose: .view() would only reinterpret the
        # channel-major memory and mix features across points.
        y = x.transpose(1, 2).contiguous()

        return y


class dotdict(dict):
    """A dict whose keys can also be read, set and deleted as attributes.

    Reading a missing attribute returns None (like ``dict.get``) rather than
    raising; deleting a missing attribute raises KeyError.
    """

    def __getattr__(self, name):
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]


# Shared model hyper-parameters: neighbourhood size, embedding width, dropout.
d = dict(k=40, emb_dims=1024, dropout=0.5)
args = dotdict(d)


In [13]:
model_old = DGCNN_partseg(args, seg_num_all=50).cuda(dev)
# Restrict DataParallel to the GPUs chosen above. Without device_ids it uses
# every visible GPU and expects the module on device_ids[0] == cuda:0, which
# conflicts with placing it on `dev` (available_gpus[0]).
model_old_dist = nn.DataParallel(model_old, device_ids=available_gpus or None)

In [14]:
# map_location="cpu" avoids materialising the checkpoint on whichever GPU it
# was saved from (possibly one outside available_gpus); load_state_dict copies
# the values onto the model's own device. torch.load unpickles the file, so
# only load trusted checkpoints.
model_old_dist.load_state_dict(torch.load("outputs/partseg/models/dgcnn.pt", map_location="cpu"))


<All keys matched successfully>

In [25]:
# Build the encoder from the shared config so it matches the partseg model.
model = DGCNN(k=args.k, emb_dim=args.emb_dims).cuda(dev)
# map_location="cpu": load off-GPU, then load_state_dict copies onto `dev`.
model.load_state_dict(torch.load('ckpts/dgcnn.pt', map_location="cpu"))

<All keys matched successfully>

Conv2d(6, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)

In [41]:
# Pick conv6's Conv1d weight from both networks by attribute name rather than
# by positional child index (16 in DGCNN_partseg, 5 in DGCNN), which silently
# changes if modules are added or reordered. DataParallel wraps the model as
# `.module`.
original_layer = model_old_dist.module.conv6[0].weight
new_layer = model.conv6[0].weight

In [33]:
new_layer.size()

torch.Size([64, 6, 1, 1])

In [35]:
original_layer.size()

torch.Size([64, 6, 1, 1])

In [34]:
def to_numpy(tensor):
    """Return ``tensor`` as a host NumPy array.

    Torch tensors are detached (a harmless no-op when they do not require
    grad) and moved to the CPU before conversion. Anything else, such as a
    NumPy array or a plain sequence, is passed through ``np.asarray``.
    """
    if isinstance(tensor, torch.Tensor):
        return tensor.detach().cpu().numpy()
    return np.asarray(tensor)


In [42]:
# The weights selected above should agree between the two checkpoints.
np.testing.assert_allclose(
    actual=to_numpy(new_layer),
    desired=to_numpy(original_layer),
    rtol=1e-03,
    atol=1e-05,
)
