# Gradients and GCN Features Fusion Transformer for Point Cloud Segmentation

In [1]:
!gpustat

[1m[37m8d809b5da21a       [m  Tue Aug 16 07:33:26 2022  [1m[30m460.73.01[m
[36m[0][m [34mGeForce RTX 3090[m |[1m[31m 68'C[m, [1m[32m 92 %[m | [36m[1m[33m10241[m / [33m24268[m MB |
[36m[1][m [34mGeForce RTX 3090[m |[1m[31m 59'C[m, [1m[32m 59 %[m | [36m[1m[33m13846[m / [33m24268[m MB |
[36m[2][m [34mGeForce RTX 3090[m |[31m 44'C[m, [32m  0 %[m | [36m[1m[33m10325[m / [33m24268[m MB |
[36m[3][m [34mGeForce RTX 3090[m |[1m[31m 54'C[m, [32m  0 %[m | [36m[1m[33m15031[m / [33m24268[m MB |
[36m[4][m [34mGeForce RTX 3090[m |[31m 43'C[m, [32m  0 %[m | [36m[1m[33m 1287[m / [33m24268[m MB |
[36m[5][m [34mGeForce RTX 3090[m |[1m[31m 56'C[m, [1m[32m 86 %[m | [36m[1m[33m23739[m / [33m24268[m MB |
[36m[6][m [34mGeForce RTX 3090[m |[1m[31m 55'C[m, [1m[32m 91 %[m | [36m[1m[33m23739[m / [33m24268[m MB |
[36m[7][m [34mGeForce RTX 3090[m |[31m 30'C[m, [32m  0 %[m | [36m[1m[33m    8[m 

In [2]:
available_gpus = [0]
dev = None if len(available_gpus) == 0 else available_gpus[0]

### Imports


In [3]:
import gc
import copy

import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Subset
from data import ShapeNetPart

In [4]:
def clear_mem(*objs):
    for obj in objs:
        del obj
    gc.collect()
    torch.cuda.empty_cache()

## DGCNN


In [5]:
def knn(x, k):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
    # (batch_size, num_points, k)
    idx = pairwise_distance.topk(k=k, dim=-1)[1]
    return idx


def get_graph_feature(x, k, knn_only=False, disp_only=False):
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    idx = knn(x, k=k)   # (batch_size, num_points, k)
    device = x.get_device()
    idx_base = torch.arange(0, batch_size).view(-1, 1, 1) * num_points
    idx_base = idx_base.cuda(device)
    idx = idx + idx_base
    idx = idx.view(-1)
    _, num_dims, _ = x.size()
    # (batch_size, num_points, num_dims)  -> (batch_size*num_points, num_dims) #   batch_size * num_points * k + range(0, batch_size*num_points)
    x = x.transpose(2, 1).contiguous()
    feature = x.view(batch_size*num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    if knn_only:
        return feature
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    if disp_only:
        return (feature - x).permute(0, 3, 1, 2)
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)


In [6]:
class Transform_Net(nn.Module):
    def __init__(self, k):
        super(Transform_Net, self).__init__()

        self.k = k

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(1024)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn4 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn5 = nn.BatchNorm1d(256)

        self.transform = nn.Linear(256, 3*3)
        nn.init.constant_(self.transform.weight, 0)
        nn.init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        # x (B x 3 x N)
        batch_size = x.size(0)
        # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = get_graph_feature(x, k=self.k)
        
        t = self.conv1(t)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        t = self.conv2(t)                       # (batch_size, 64, num_points, k) -> (batch_size, 128, num_points, k)
        t = t.max(dim=-1, keepdim=False)[0]     # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)
        
        t = self.conv3(t)                       # (batch_size, 128, num_points) -> (batch_size, 1024, num_points)
        t = t.max(dim=-1, keepdim=False)[0]     # (batch_size, 1024, num_points) -> (batch_size, 1024)
        
        t = F.leaky_relu(self.bn4(self.linear1(t)), negative_slope=0.2)     # (batch_size, 1024) -> (batch_size, 512)
        t = F.leaky_relu(self.bn5(self.linear2(t)), negative_slope=0.2)     # (batch_size, 512) -> (batch_size, 256)
        
        t = self.transform(t)                   # (batch_size, 256) -> (batch_size, 3*3)
        t = t.view(batch_size, 3, 3)            # (batch_size, 3*3) -> (batch_size, 3, 3)
        
        # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)
        # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)
        # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
        x = x.transpose(2, 1)
        return x

In [7]:
class DGCNN(nn.Module):
    def __init__(self, k, emb_dim):
        super(DGCNN, self).__init__()
        self.k = k
        self.emb_dim = emb_dim
    
        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   nn.BatchNorm2d(64),
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, emb_dim, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(emb_dim),
                                   nn.LeakyReLU(negative_slope=0.2))
        

    def forward(self, x):
        batch_size = x.size(0)
        num_points = x.size(2)

        x = get_graph_feature(x, k=self.k)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv5(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)      # (batch_size, 64*3, num_points)
        x = self.conv6(x)                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        # (batch_size, num_points, emb_dims)
        y = x.view(batch_size, num_points, self.emb_dim)
        
        return y

## Gradients & HOG 3D

HOG is not implemented yet. Will implement later after creating pipeline with SVD only.

In [8]:
def get_gradients(x, k, do_pca=False):
    '''
    x (Bx3xN) batch of point clouds
    return gradients (BxNx3): direction of maximimal variance at each point
    '''
    x_nn = get_graph_feature(x, k=k, knn_only=True)  # Bx3xN -> BxNxkx3
    if do_pca:
        _, _, v = torch.pca_lowrank(x_nn)  # BxNxkx3 -> BxNx3x3
        
    else:
        mean = x_nn.mean(dim=2).unsqueeze(dim=2)
        centered = x_nn - mean
        _, _, v = torch.linalg.svd(centered)  # BxNxkx3 -> BxNx3x3
    gradients = v[:, :, 0]  # BxNx3x3 -> BxNx3
    return gradients


In [8]:
# def get_angle_histogram(gradients, choice="cartesian"):
#     '''
#     Takes input gradients (BxNx3)
#     Computes angles in cartesian or spherical coordinates
#     Returns an array (Bx180x3 or Bx100x2) of angle frequencie
#     '''
#     assert choice in ("cartesian", "spherical")
#     # gradients = gradients.cpu()
#     batch_size = gradients.size(0)
#     if choice == "cartesian":
#         angles = torch.acos(gradients) * 180.0 / math.pi
#         freq_table = torch.zeros((batch_size, 180, 3))
#         for i in range(batch_size):
#             for j in range(3):
#                 freq_table[i, :, j] = torch.histc(
#                     angles[i, :, j], bins=180, min=0, max=180)
#     else:
#         angles = torch.empty((gradients.size(0), gradients.size(1), 2))
#         # theta
#         angles[:, :, 0] = torch.atan(gradients[:, :, 1] / gradients[:, :, 0])
#         # pi
#         angles[:, :, 1] = torch.acos(gradients[:, :, 2])
#         angles = angles * 180 / math.pi
#         # range is [-100,100]
#         freq_table = torch.zeros((batch_size, 100, 2))
#         for i in range(batch_size):
#             freq_table[i, :, 0] = torch.histc(angles[i, :, 0])
#             freq_table[i, :, 1] = torch.histc(angles[i, :, 1])
#     # get density instead of count
#     return freq_table / 2048


In [9]:
# bs = 10
# n = 2048

# train_dataset = ShapeNetPart(n, 'trainval')
# train_loader = DataLoader(train_dataset, num_workers=2, batch_size=bs,
#                           shuffle=True, drop_last=False)

# k = 20

# for data, _, _ in train_loader:
#     print(data.shape)
#     data = data.cuda().permute(0, 2, 1)
#     print(data.shape)
#     break

In [10]:
# gradients = get_gradients(data, k=20)
# hist = get_angle_histogram(gradients)
# for i in range(bs):
#     plt.figure(figsize=(10, 5))
#     plt.bar(range(180), hist[i, :, 0], alpha=0.5)
#     plt.bar(range(180), hist[i, :, 1], alpha=0.5)
#     plt.bar(range(180), hist[i, :, 2], alpha=0.5)
#     plt.legend(labels=['x', 'y', 'z'])

In [11]:
# hist = get_angle_histogram(gradients, 'spherical')
# for i in range(bs):
#     plt.figure(figsize=(10, 5))
#     plt.bar(range(100), hist[i, :, 0], alpha=0.5)
#     plt.bar(range(100), hist[i, :, 1], alpha=0.5)
#     plt.legend(labels=['theta', 'phi'])

## Encoder

**Point Transformer**

1. Decrease `emb_dims` by 1x1 conv like DETR
2. Use attention mech of Point Transformer (Hengshuang)
3. Resize to original dim
3. Design encoder of original transformer (without multihead for now)

**PointBERT**

- Patch based embeddings... (1 embedding for 1 patch)


### Point Transformer by [qq456cvb](https://github.com/qq456cvb/Point-Transformers/blob/master/models/Hengshuang/transformer.py)

In [9]:
def square_distance(src, dst):
    return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1)


def index_points(points, idx):
    raw_size = idx.size()
    idx = idx.reshape(raw_size[0], -1)
    res = torch.gather(
        points, 1, idx[..., None].expand(-1, -1, points.size(-1)))
    return res.reshape(*raw_size, -1)

In [10]:
class PointTransformerLayer(nn.Module):
    def __init__(self, d_points=256, d_model=64, k=16) -> None:
        super(PointTransformerLayer, self).__init__()

        self.k = k

        self.fc1 = nn.Linear(d_points, d_model)
        self.fc2 = nn.Linear(d_model, d_points)

        self.fc_delta = nn.Sequential(
            nn.Linear(3, d_model, bias=True),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        self.fc_gamma = nn.Sequential(
            nn.Linear(d_model, d_model, bias=True),
            nn.ReLU(),
            nn.Linear(d_model, d_model)
        )
        self.w_qs = nn.Linear(d_model, d_model, bias=False)
        self.w_ks = nn.Linear(d_model, d_model, bias=False)
        self.w_vs = nn.Linear(d_model, d_model, bias=False)
        
        
    def forward(self, xyz, features):
        # xyz: b x n x 3, features: b x n x f
        dists = square_distance(xyz, xyz)
        knn_idx = dists.argsort()[:, :, :self.k]  # b x n x k
        knn_xyz = index_points(xyz, knn_idx)
        
        pre = features
        x = self.fc1(features)
        q, k, v = self.w_qs(x), index_points(self.w_ks(x), knn_idx), index_points(self.w_vs(x), knn_idx)

        pos_enc = self.fc_delta(xyz[:, :, None] - knn_xyz)  # b x n x k x f
        
        attn = self.fc_gamma(q[:, :, None] - k + pos_enc)
        attn = F.softmax(attn, dim=-2)  # b x n x k x f
        attn = F.normalize(attn, p=1.0, dim=-2)
        
        res = torch.einsum('bmnf,bmnf->bmf', attn, v + pos_enc)
        res = self.fc2(res) + pre
        return res

### Encoder Layer

In [11]:
def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module)] * N)

In [12]:
# https://github.com/POSTECH-CVLab/point-transformer/blob/master/model/pointtransformer/pointtransformer_seg.py

class EncoderLayer(nn.Module):

    def __init__(self,
                 in_channels,
                 mid_channels=256,
                 out_channels=None,
                 ):
        super(EncoderLayer, self).__init__()
        # output has same dim has input
        out_channels = in_channels if out_channels is None else out_channels
        # only use for large dim inputs
        mid_channels = mid_channels if in_channels > mid_channels else in_channels

        self.scale_dim = nn.Sequential(
            nn.Conv1d(in_channels, mid_channels, 1, bias=False),
            nn.BatchNorm1d(mid_channels),
            nn.ReLU(inplace=False)
        )

        self.attention = PointTransformerLayer(d_points=mid_channels)
        self.bn = nn.BatchNorm1d(mid_channels)

        self.restore_dim = nn.Sequential(
            nn.Conv1d(mid_channels, out_channels, 1, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=False)
        )

        self.skip_conn = nn.Sequential(
            nn.LayerNorm(out_channels),
            nn.ReLU()
        )
            

    def forward(self, p, x):
        # x (B x N x C)
        # p (B x N x 3)

        bs = x.size(0)
        n = x.size(1)

        # x -> B x C x N -> B x mid_channels x N
        y = self.scale_dim(x.transpose(1, 2))
        # p (B x N x 3) & y (B x mid_channels x N) -> (B x mid_channels x N )
        y = F.relu(self.bn(self.attention(
            p, y.transpose(1, 2)).transpose(1, 2)))
        # (B x N x mid_channels) -> (B x out_channels x N) -> (B x N x out_channels)

        y = self.restore_dim(y)
        y = y.transpose(1, 2)
        # skip connection (won't work if in_channels != out_channels)
        y = y + x
        y = self.skip_conn(y)
        return y

In [13]:
class TransformerEncoder(nn.Module):
    def __init__(self, in_channels, num_layers, **kwargs):
        super(TransformerEncoder, self).__init__()
        
        self.layers = _get_clones(EncoderLayer(in_channels=in_channels, **kwargs), num_layers)
        self.norm = nn.LayerNorm(in_channels)

    def forward(self, pc, x):
        for layer in self.layers:
            x = layer(pc, x)
        y = self.norm(x)
        return y

## Fusion

In [14]:
class MultiHeadAttention(nn.Module):
    '''
    Self attention with multiple heads
    '''

    def __init__(self, d_graph, d_grads, d_k, num_heads, d_v=None,dropout=.1):
        '''
        :param d_model: Output dimensionality of the model
        :param d_k: Dimensionality of queries and keys
        :param d_v: Dimensionality of values
        :param h: Number of heads
        '''
        super(MultiHeadAttention, self).__init__()

        self.d_graph = d_graph
        self.d_grads = d_grads
        self.d_k = d_k
        self.d_v = d_k if d_v is None else d_v
        self.h = num_heads

        self.fc_q = nn.Linear(self.d_grads, self.h * self.d_k)
        self.fc_k = nn.Linear(self.d_graph, self.h * self.d_k)
        self.fc_v = nn.Linear(self.d_graph, self.h * self.d_v)
        self.fc_o = nn.Linear(self.h * self.d_v, self.d_graph)
        self.dropout = nn.Dropout(dropout)

        self.init_weights()


    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        '''
        Computes
        :param queries: Queries (b_s, nq, d_grads)
        :param keys: Keys (b_s, nk, d_graph)
        :param values: Values (b_s, nk, d_graph)
        :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
        :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
        :return:
        '''
        b_s, nq = queries.shape[:2]
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)
        k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
        v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        att = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, nk)
        if attention_weights is not None:
            att = att * attention_weights
        if attention_mask is not None:
            att = att.masked_fill(attention_mask, -np.inf)

        att = F.softmax(att, dim=-1)
        att = F.normalize(att, p=1.0, dim=-1)
        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_graph)
        return out

In [15]:
class Offset_Attention(nn.Module):

    def __init__(self, in_channels_graph, in_channels_grads, 
                 mid_channels=64, out_channels=None, 
                 num_heads=8, dropout=0.1):
        super(Offset_Attention, self).__init__()

        if out_channels is None:
            out_channels = in_channels_graph

        self.attention = MultiHeadAttention(num_heads=num_heads, d_graph=in_channels_graph,
                                            d_grads=in_channels_grads, d_k=mid_channels, 
                                            dropout=dropout)
        
        self.lbr = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.ReLU(inplace=False)
            )

    def forward(self, grads, graph):
        # grads (B x N x 3)
        # graph (B x N x F)
        attn_output = self.attention(queries=grads, keys=graph, values=graph)
        attn_output = attn_output - graph
        y = self.lbr(attn_output.transpose(1, 2)).transpose(1, 2)
        return y

## Decoder

- Pointformer LGT block "*adopts a multi-scale cross-attention module to build connections between local features ... and global features*"

In [39]:
class Decoder(nn.Module):
    def __init__(self, d_model, nheads, nlayers, n_classes) -> None:
        super(Decoder, self).__init__()

        layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nheads,
                                           batch_first=True)
        self.decoder = nn.TransformerDecoder(layer, num_layers=nlayers)
        

    def forward(self, tgt, memory):
        # tgt (B x N x F) memory (B x N x F)
        tgt = self.decoder(tgt, memory)  # (B x N x F)
        return tgt

## Classification Layer

In [40]:
class Classifier(nn.Module):
    def __init__(self, d_model, n_classes):
        super(Classifier, self).__init__()

        self.clf = nn.Sequential(
            nn.Conv1d(d_model, d_model // 8, 1, bias=False),
            nn.BatchNorm1d(d_model // 8),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(d_model // 8, d_model // 64, 1, bias=False),
            nn.BatchNorm1d(d_model // 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Conv1d(d_model // 64, n_classes, 1, bias=False)
        )

    def forward(self, x):
        # x (B x N x F)
        return self.clf(x.transpose(1, 2))  # (B x n_classes x N)


## Test Arch

In [41]:
class Net(nn.Module):
    def __init__(self, k, emb_dim, nlayers, nclasses):
        super(Net, self).__init__()
        # number of nearest neighbors
        self.k = k
        # transform to canonical representation
        self.tnet = Transform_Net(k=k)
        self.tnet.load_state_dict(torch.load("ckpts/tnet.pt"))
        # get graph features
        self.dgcnn = DGCNN(k=k, emb_dim=emb_dim)
        self.dgcnn.load_state_dict(torch.load("ckpts/dgcnn.pt"))
        # produce attn_map from graph features
        self.graph_encoder = TransformerEncoder(
            in_channels=emb_dim, num_layers=nlayers)
        # produce attn_map from gradients
        self.gradients_encoder = TransformerEncoder(
            in_channels=3, num_layers=nlayers)
        # fuse gradients and graph attention
        self.fusion_net = Offset_Attention(in_channels_graph=emb_dim, in_channels_grads=3)
        # get segmap
        # self.decoder = Decoder(d_model=emb_dim, nheads=4,
        #                        nlayers=nlayers, n_classes=50)
        self.clf = Classifier(d_model=emb_dim, n_classes=nclasses)

    def forward(self, x):
        # x = B x 3 x N
        pcd = self.tnet(x) # B x 3 x N
        # B x 3 x N -> B x N x emb_dim
        graph_ftrs = self.dgcnn(pcd)
        # B x N x 3
        gradient_ftrs = get_gradients(x, k=self.k)
        # B x 3 x N -> B x N x 3
        pcd = pcd.transpose(1, 2)
        # (B x N x 3) & (B x N x emb_dim) -> B x N x emb_dim
        graph_attn_map = self.graph_encoder(pcd, graph_ftrs)
        # (B x N x 3) & (B x N x 3) -> B x N x 3
        gradients_attn_map = self.gradients_encoder(pcd, gradient_ftrs)
        # (B x N x 3) & (B x N x emb_dim) -> B x N x emb_dim
        fused_attn_map = self.fusion_net(gradients_attn_map, graph_attn_map)
        # (B x N x emb_dim) & (B x N x emb_dim) -> B x n_classes x N
        # segmap = self.decoder(graph_ftrs, fused_attn_map)
        scores = self.clf(fused_attn_map)
        return scores


In [42]:
bs = 4
emb_dim = 1024
k = 16
dev = 4

In [44]:
model = Net(k=k, emb_dim=emb_dim, nlayers=4, nclasses=50).cuda(dev)

In [45]:
sum(p.numel() for p in model.parameters())

4402428

In [24]:
!gpustat

[1m[37ma1aa4d27e2e8       [m  Fri Aug 12 16:14:06 2022  [1m[30m460.73.01[m
[36m[0][m [34mGeForce RTX 3090[m |[1m[31m 67'C[m, [1m[32m 95 %[m | [36m[1m[33m13287[m / [33m24268[m MB |
[36m[1][m [34mGeForce RTX 3090[m |[31m 49'C[m, [32m  0 %[m | [36m[1m[33m11669[m / [33m24268[m MB |
[36m[2][m [34mGeForce RTX 3090[m |[31m 45'C[m, [32m  0 %[m | [36m[1m[33m10325[m / [33m24268[m MB |
[36m[3][m [34mGeForce RTX 3090[m |[1m[31m 63'C[m, [1m[32m 82 %[m | [36m[1m[33m21144[m / [33m24268[m MB |
[36m[4][m [34mGeForce RTX 3090[m |[31m 44'C[m, [32m  0 %[m | [36m[1m[33m 1287[m / [33m24268[m MB |
[36m[5][m [34mGeForce RTX 3090[m |[1m[31m 60'C[m, [1m[32m 40 %[m | [36m[1m[33m23833[m / [33m24268[m MB |
[36m[6][m [34mGeForce RTX 3090[m |[1m[31m 64'C[m, [1m[32m 91 %[m | [36m[1m[33m17893[m / [33m24268[m MB |
[36m[7][m [34mGeForce RTX 3090[m |[1m[31m 65'C[m, [1m[32m 82 %[m | [36m[1m[33m19527

In [25]:
# x = torch.randn((2, 2048, 3), device=torch.device(dev)).transpose(1, 2)
# y = model(x)
# y.shape

## Ignite Testing

In [26]:
from ignite.engine import Engine, Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import ConfusionMatrix, IoU, Loss, mIoU
from ignite.handlers import ModelCheckpoint, global_step_from_engine
from ignite.contrib.handlers import TensorboardLogger


In [27]:
device = torch.device(dev)

In [28]:
dataset = ShapeNetPart(2048, 'train', task='seg')
N = len(dataset)

train_ds = Subset(dataset, list(range(2)))
val_ds = Subset(dataset, [i for i in range(N-2, N)])

In [29]:
train_loader = DataLoader(train_ds, batch_size=2)
val_loader = DataLoader(val_ds, batch_size=2)


In [30]:
model = Net(k=16, emb_dim=1024, ).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=5e-3)
criterion = nn.CrossEntropyLoss()

In [31]:
miou = JaccardIndex(50).cuda(dev)

In [32]:
for epoch in range(100):
    optimizer.zero_grad()
    total_loss, total_iou = 0, 0
    for b, (x, y) in enumerate(train_loader):
        x, y = x.cuda(dev), y.cuda(dev)
        y_out = model(x)
        loss = criterion(y_out, y)
        total_loss += loss
        preds = torch.argmax(y_out, dim=1)
        total_iou += miou(preds, y)
        loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print(f"{epoch} Loss:", (total_loss / 4).item(),
              " mIOU:", (total_iou / 4).item())


0 Loss: 0.9777383208274841  mIOU: 0.0007681986317038536
10 Loss: 0.2872034013271332  mIOU: 0.005980056244879961
20 Loss: 0.26222798228263855  mIOU: 0.00602493854239583
30 Loss: 0.2574349641799927  mIOU: 0.00583995645865798
40 Loss: 0.2558836042881012  mIOU: 0.005877274088561535
50 Loss: 0.25441649556159973  mIOU: 0.0059401593171060085
60 Loss: 0.25282928347587585  mIOU: 0.005973484832793474
70 Loss: 0.2522369623184204  mIOU: 0.005732319783419371
80 Loss: 0.25165092945098877  mIOU: 0.005835120566189289
90 Loss: 0.25097793340682983  mIOU: 0.005756357219070196


KeyboardInterrupt: 

In [None]:
trainer = create_supervised_trainer(model, optimizer, criterion, device)

cm_metric = ConfusionMatrix(num_classes=50)
val_metrics = {"IoU": mIoU(cm_metric), "loss": Loss(criterion)}

train_evaluator = create_supervised_evaluator(
    model, metrics=val_metrics, device=device)
val_evaluator = create_supervised_evaluator(
    model, metrics=val_metrics, device=device)


In [None]:
# How many batches to wait before logging training status
log_interval = 2


@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    print(
        f"Epoch[{engine.state.epoch}], Iter[{engine.state.iteration}] Loss: {engine.state.output:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    train_evaluator.run(train_loader)
    metrics = train_evaluator.state.metrics
    print(
        f"Training Results - Epoch[{trainer.state.epoch}] Avg IoU: {metrics['IoU']:.2f} Avg loss: {metrics['loss']:.2f}")


@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    val_evaluator.run(val_loader)
    metrics = val_evaluator.state.metrics
    print(
        f"Validation Results - Epoch[{trainer.state.epoch}] Avg IoU: {metrics['IoU']:.2f} Avg loss: {metrics['loss']:.2f}")


In [None]:
# Score function to return current value of any metric we defined above in val_metrics
def score_function(engine):
    return engine.state.metrics["IoU"]


# Checkpoint to store n_saved best models wrt score function
model_checkpoint = ModelCheckpoint(
    "checkpoint",
    n_saved=2,
    filename_prefix="best",
    score_function=score_function,
    score_name="IoU",
    require_empty=False,
    global_step_transform=global_step_from_engine(
        trainer),  # helps fetch the trainer's state
)

# Save the model after every epoch of val_evaluator is completed
val_evaluator.add_event_handler(
    Events.COMPLETED, model_checkpoint, {"model": model})


<ignite.engine.events.RemovableEventHandle at 0x7f2c1c3d8d90>

In [None]:
# Define a Tensorboard logger
tb_logger = TensorboardLogger(log_dir="tb-logger")

# Attach handler to plot trainer's loss every 100 iterations
tb_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED(every=100),
    tag="training",
    output_transform=lambda loss: {"batch_loss": loss},
)

# Attach handler for plotting both evaluators' metrics after every epoch completes
for tag, evaluator in [("training", train_evaluator), ("validation", val_evaluator)]:
    tb_logger.attach_output_handler(
        evaluator,
        event_name=Events.EPOCH_COMPLETED,
        tag=tag,
        metric_names="all",
        global_step_transform=global_step_from_engine(trainer),
    )


In [None]:
trainer.run(train_loader, max_epochs=10)

Training Results - Epoch[1] Avg IoU: 0.00 Avg loss: 2.88
Validation Results - Epoch[1] Avg IoU: 0.00 Avg loss: 3.32
Epoch[2], Iter[2] Loss: 3.39
Training Results - Epoch[2] Avg IoU: 0.01 Avg loss: 15.34
Validation Results - Epoch[2] Avg IoU: 0.01 Avg loss: 13.66
Training Results - Epoch[3] Avg IoU: 0.00 Avg loss: 10.26
Validation Results - Epoch[3] Avg IoU: 0.00 Avg loss: 20.52
Epoch[4], Iter[4] Loss: 3.13
Training Results - Epoch[4] Avg IoU: 0.00 Avg loss: 3.28
Validation Results - Epoch[4] Avg IoU: 0.00 Avg loss: 4.41
Training Results - Epoch[5] Avg IoU: 0.00 Avg loss: 3.85
Validation Results - Epoch[5] Avg IoU: 0.00 Avg loss: 3.92
Epoch[6], Iter[6] Loss: 2.05
Training Results - Epoch[6] Avg IoU: 0.00 Avg loss: 5.07
Validation Results - Epoch[6] Avg IoU: 0.00 Avg loss: 13.04
Training Results - Epoch[7] Avg IoU: 0.02 Avg loss: 1.33
Validation Results - Epoch[7] Avg IoU: 0.02 Avg loss: 5.44
Epoch[8], Iter[8] Loss: 1.41
Training Results - Epoch[8] Avg IoU: 0.02 Avg loss: 1.13
Validation

State:
	iteration: 10
	epoch: 10
	epoch_length: 1
	max_epochs: 10
	output: 1.2469780445098877
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>

In [None]:
# Let's close the logger and inspect our results
tb_logger.close()

%load_ext tensorboard

%tensorboard - -logdir = .


In [None]:
# At last we can view our best models
!ls checkpoints
