In [None]:
import torch
try:
    import torchbearer
except:
    !pip install torchbearer
    import torchbearer

from torchbearer import Trial
from torchbearer import callbacks
from torch import nn
import torch.nn.functional as F
from torchbearer.callbacks.torch_scheduler import LambdaLR
from torchbearer.callbacks import Callback
import numpy as np
import h5py
from torch.utils.data import Dataset, DataLoader
import os
from time import time

In [None]:
# Record which PyTorch build this notebook was run with.
torch_version = torch.__version__
print(torch_version)

In [None]:
class LRPrinter(Callback):
    """torchbearer callback that prints the learning rate of the optimizer's first
    parameter group at the start of every epoch."""

    def __init__(self, opt):
        # Keep a handle on the optimizer so its param_groups can be inspected each epoch.
        self._opt = opt

    def on_start_epoch(self, state):
        first_group = self._opt.param_groups[0]
        print(first_group['lr'])
        
# Directory for per-epoch weight checkpoints.
# NOTE(review): absolute Windows path — adjust for the machine running the notebook.
SAVE_DIR = 'H:/pretrained_part'

@callbacks.on_end_epoch
def save_state(state):
    """Save the model's ``state_dict`` at the end of every epoch.

    Writes ``<SAVE_DIR>/pretrained_sn_<epoch>.weights``. The directory is created if it is
    missing; otherwise ``torch.save`` would raise at the end of the first epoch and abort training.
    """
    os.makedirs(SAVE_DIR, exist_ok=True)
    epoch = state[torchbearer.EPOCH]
    torch.save(state[torchbearer.MODEL].state_dict(), f'{SAVE_DIR}/pretrained_sn_{epoch}.weights')

In [None]:
def scale_to_unit_cube(vertices):
    """Mean-centre a point cloud and scale it, in place, into [-0.5, 0.5] along every axis.

    Args:
        vertices: float array of points, one per row (assumed (num_points, 3)); modified in place.

    Returns:
        The same ``vertices`` array. It is returned for convenience; existing callers rely only
        on the in-place update.
    """
    if vertices.size == 0:
        # Nothing to normalise (and np.amax would raise on an empty array).
        return vertices
    # Mean-centre each coordinate.
    vertices -= np.mean(vertices, axis=0)
    # Scale so the largest absolute coordinate becomes 0.5. A degenerate cloud (all points
    # identical) has zero extent; dividing would fill it with NaNs, so leave it at the origin.
    extent = np.amax(np.abs(vertices))
    if extent > 0:
        vertices /= 2 * extent
    return vertices
    
def voxelise(vertices, k):
    """Assign every point to one cell of a k x k x k voxel grid.

    Points are expected in [-0.5, 0.5]^3 (as produced by ``scale_to_unit_cube``). Each
    coordinate is shifted to ``d`` in [0, 1] and bucketed as ``floor(d * k)``. ``d == 1``
    belongs to the last cell. Coordinates outside the cube are clamped to the nearest boundary
    cell, so every returned id is valid.

    Args:
        vertices: array with one point per row; the first three columns are x, y, z.
        k: number of cells per axis.

    Returns:
        Integer array of shape (num_points,) with ids ``x + y*k + z*k**2`` in ``[0, k**3)``.
    """
    vertices = np.asarray(vertices)
    if len(vertices) == 0:
        return np.empty(0, dtype=int)
    # Shift [-0.5, 0.5] -> [0, 1] and bucket. Clipping keeps d == 1, a floating-point round-up
    # to k, and out-of-range points (which previously produced negative ids) inside the grid.
    cells = np.clip(np.floor((vertices[:, :3] + 0.5) * k), 0, k - 1).astype(int)
    # Flatten (x, y, z) cell coordinates to one id, x varying fastest.
    return cells @ (k ** np.arange(3))

def spread(index, k):
    """Split a flat voxel index into its (x, y, z) grid coordinates.

    x is the remainder modulo ``k``. y and z come from ``np.floor`` divisions, so they are
    floats. z is not reduced modulo ``k``.
    """
    row = np.floor(index / k)
    layer = np.floor(index / k ** 2)
    return index % k, row % k, layer

def get_voxel_xyz(vox_id, k):
    """Return the minimum corner of voxel ``vox_id``, in units of the grid extent.

    Inverts the id layout ``x + y*k + z*k**2`` used by ``voxelise`` and divides each integer
    cell coordinate by ``k``.

    Args:
        vox_id: integer voxel id, or an integer array of ids. An array is not modified
            (the previous in-place ``-=`` overwrote the caller's array).
        k: cells per axis.

    Returns:
        ``np.asarray([x, y, z]) / k``. This has shape (3,) for a scalar id, or (3, n) for n ids.
    """
    z, remainder = divmod(vox_id, k ** 2)
    y, x = divmod(remainder, k)
    return np.asarray([x, y, z]) / k

def _voxel_corners(ids, k):
    """Return the (n, 3) minimum corners, in units of the grid extent, of the voxels ``ids``."""
    z, remainder = np.divmod(ids, k * k)
    y, x = np.divmod(remainder, k)
    return np.stack((x, y, z), axis=-1) / k


def shuffle_voxels(vertices, vox_ids, k):
    """Randomly permute the voxels of a point cloud.

    A random permutation of the ``k**3`` voxel ids is drawn from NumPy's global RNG. Every
    point is translated from its own voxel to the voxel its id maps to (subtract the source
    corner, add the destination corner). Points sharing a voxel therefore keep their relative
    positions.

    Args:
        vertices: array of shape (num_points, 3).
        vox_ids: integer voxel id of each point (one per row of ``vertices``), in ``[0, k**3)``.
        k: voxels per axis.

    Returns:
        New array with the same shape and dtype as ``vertices``, holding the shuffled coordinates.

    Raises:
        ValueError: if any id lies outside ``[0, k**3)``. Negative ids would otherwise wrap
            around silently under NumPy indexing.
    """
    n_voxels = k ** 3
    ids = np.asarray(vox_ids)
    if ids.size and (ids.min() < 0 or ids.max() >= n_voxels):
        raise ValueError(f'voxel ids must lie in [0, {n_voxels}), got range [{ids.min()}, {ids.max()}]')

    # Same RNG draw as permuting arange(k**3): voxel v is sent to permutation[v].
    permutation = np.random.permutation(n_voxels)

    source_corners = _voxel_corners(ids, k)
    target_corners = _voxel_corners(permutation[ids], k)
    shuffled_vertices = np.empty_like(vertices)
    shuffled_vertices[...] = vertices - source_corners + target_corners
    return shuffled_vertices
    
def generate_self_supervised(dataset_arr, k=3):
    """Compute per-point voxel ids for every point cloud in ``dataset_arr``.

    Each cloud is first normalised IN PLACE with ``scale_to_unit_cube``, so ``dataset_arr``
    itself is modified. It is then bucketed into a ``k x k x k`` grid by ``voxelise``.

    Args:
        dataset_arr: iterable of point-cloud arrays (presumably (num_points, 3) each). Each
            element is modified in place.
        k: voxels per axis (default 3, the previously hard-coded grid size).

    Returns:
        List with one integer voxel-id array per point cloud.
    """
    id_list = []
    for pointcloud in dataset_arr:
        scale_to_unit_cube(pointcloud)
        id_list.append(voxelise(pointcloud, k))
    return id_list

In [None]:
class H5Dataset(Dataset):
    """Point-cloud dataset loaded from HDF5 files.

    Every ``*.h5`` file under ``dataset_path`` (searched recursively) whose name contains
    ``partition_name`` is read. Each file must provide ``data`` (indexed as
    [models, points, dims]) and ``label`` (presumably (models, 1); the trailing axis is
    squeezed). The points of every model are shuffled and the first ``num_points`` are kept.
    Items are ``(points, label)`` with ``points`` of shape (dims, num_points).
    """

    def __init__(self, dataset_path, partition_name, num_points=1024):
        data_chunks, label_chunks = [], []

        for root, _dirs, files in os.walk(dataset_path):
            for filename in files:
                if filename.endswith('.h5') and partition_name in filename:
                    print('Loading '+filename)
                    # Join with the walked directory so files in sub-folders (and paths without
                    # a trailing slash) resolve correctly.
                    f_data, f_labels = self._load_file(os.path.join(root, filename), num_points)
                    data_chunks.append(f_data)
                    label_chunks.append(f_labels)

        n_files = len(data_chunks)
        # Concatenate once rather than growing tensors inside the loop.
        data = torch.cat(data_chunks) if data_chunks else torch.empty(0)
        labels = torch.cat(label_chunks) if label_chunks else torch.empty(0, dtype=torch.long)

        self.size = len(labels)
        self.data = data
        print(str(data.shape[0]) + ' ' + partition_name + ' models loaded from ' + str(n_files) + ' files.')
        self.labels = labels

    @staticmethod
    def _load_file(path, num_points):
        """Read one HDF5 file, shuffle each model's points and keep the first ``num_points``."""
        with h5py.File(path, 'r') as f:   # context manager closes the file handle
            f_data, f_labels = np.asarray(f['data']), np.asarray(f['label'])
        # np.random.shuffle permutes along the first axis of each (points, dims) cloud, so every
        # point's coordinates move together. np.apply_along_axis(np.random.shuffle, 1, ...) would
        # instead shuffle each coordinate column independently and scramble the geometry.
        for cloud in f_data:
            np.random.shuffle(cloud)
        f_data = torch.tensor(f_data)[:, :num_points, :].permute(0, 2, 1).float()  # [models, points, dims] -> [models, dims, points]
        f_labels = torch.squeeze(torch.tensor(f_labels).long(), -1)
        return f_data, f_labels

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError(f'{self.__class__.__name__} index out of range.')

        return self.data[index], self.labels[index]

    def __len__(self):
        return self.size

In [None]:
def knn(x, k):
    """Indices of the ``k`` nearest neighbours of every point (each point includes itself).

    Args:
        x: tensor of shape (batch_size, num_dims, num_points).
        k: number of neighbours.

    Returns:
        Long tensor of shape (batch_size, num_points, k), ordered nearest first.
    """
    gram = torch.matmul(x.transpose(2, 1), x)               # (B, N, N) inner products
    sq_norms = (x ** 2).sum(dim=1, keepdim=True)            # (B, 1, N)
    # -||a - b||^2 = 2<a, b> - ||a||^2 - ||b||^2; the largest values are the closest points.
    neg_sq_dist = 2 * gram - sq_norms - sq_norms.transpose(2, 1)
    _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)
    return neighbour_idx


def get_graph_feature(x, k=20, idx=None):
    """Build DGCNN edge features for every point and its k nearest neighbours.

    Args:
        x: tensor expected as (batch_size, num_dims, num_points).
        k: neighbours per point. When ``idx`` is given, ``idx.size(-1)`` is used instead.
        idx: optional precomputed neighbour indices of shape (batch_size, num_points, k).
            Computed with ``knn`` when omitted.

    Returns:
        Tensor of shape (batch_size, 2*num_dims, num_points, k), holding
        ``[x_neighbour - x_centre, x_centre]`` concatenated along the channel axis.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x, k=k)   # (batch_size, num_points, k)
    k = idx.size(-1)

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points, num_dims) view below. Use the input's device rather than a
    # hard-coded 'cuda:0', so this also runs on CPU and on other GPUs.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()   # (batch_size, num_dims, num_points) -> (batch_size, num_points, num_dims)
    feature = x.view(batch_size * num_points, -1)[idx, :]   # gathered neighbours: (batch_size*num_points*k, num_dims)
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    feature = torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()

    return feature

class Transform_Net(nn.Module):
    """DGCNN spatial transformer: predicts one 3x3 matrix per sample from edge features.

    The final linear layer starts with zero weights and an identity bias, so an untrained
    network outputs the identity transform.
    """

    def __init__(self):
        super(Transform_Net, self).__init__()
        self.k = 3

        # ``bn3``/``bn4`` normalise the MLP head in ``forward``. The 1024-channel norm used by
        # ``conv3`` is owned only by that Sequential, so no attribute is reassigned later.
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(512)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   nn.BatchNorm1d(1024),
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        self.transform = nn.Linear(256, 3 * 3)
        torch.nn.init.constant_(self.transform.weight, 0)
        torch.nn.init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        batch_size = x.size(0)

        # Edge-feature encoder (B, 6, N, k) -> (B, 128, N, k), then max over the k neighbours.
        x = self.conv2(self.conv1(x))
        x = x.max(dim=-1, keepdim=False)[0]

        # Point-wise lift to 1024 channels, then global max-pool over points -> (B, 1024).
        x = self.conv3(x)
        x = x.max(dim=-1, keepdim=False)[0]

        # MLP head 1024 -> 512 -> 256.
        x = F.leaky_relu(self.bn3(self.linear1(x)), negative_slope=0.2)
        x = F.leaky_relu(self.bn4(self.linear2(x)), negative_slope=0.2)

        # Regress the nine matrix entries and reshape to (B, 3, 3).
        return self.transform(x).view(batch_size, 3, 3)

In [None]:
class DGCNN_partseg(nn.Module):
    """DGCNN part-segmentation network: per-point class scores from a point cloud plus a
    per-model category vector.

    Args:
        k: neighbours per point in every EdgeConv graph.
        emb_dims: channels of the global embedding produced by ``conv6``.
        dropout: dropout probability in the segmentation head.
        seg_num_all: number of per-point output classes.
        num_categories: length of the category vector ``l`` consumed by ``conv7``
            (default 16, the previously hard-coded value).
    """

    def __init__(self, k=20, emb_dims=1024, dropout=0.5, seg_num_all=27, num_categories=16):
        super(DGCNN_partseg, self).__init__()
        self.seg_num_all = seg_num_all
        self.k = k
        self.transform_net = Transform_Net()

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(emb_dims)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv7 = nn.Sequential(nn.Conv1d(num_categories, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        # conv8 input = [global embedding (emb_dims) | category feature (64) | x1, x2, x3 (3*64)].
        # Derived from emb_dims: the literal 1280 only matched the default emb_dims=1024.
        self.conv8 = nn.Sequential(nn.Conv1d(emb_dims + 64 + 64*3, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=dropout)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=dropout)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                   self.bn10,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, self.seg_num_all, kernel_size=1, bias=False)

    def forward(self, t):
        """Args:
            t: tuple ``(x, l)``. ``x`` holds point coordinates of shape (batch_size, 3, num_points).
               ``l`` holds per-model category vectors, reshaped to (batch_size, num_categories, 1).

        Returns:
            Per-point scores of shape (batch_size, num_points, seg_num_all).
        """
        x, l = t
        batch_size = x.size(0)
        num_points = x.size(2)

        # Spatial transformer: predict a 3x3 matrix per sample and apply it to the coordinates.
        x0 = get_graph_feature(x, k=self.k)     # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        trans = self.transform_net(x0)          # (batch_size, 3, 3)
        x = x.transpose(2, 1)                   # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, trans)                 # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)                   # (batch_size, num_points, 3) -> (batch_size, 3, num_points)

        x = get_graph_feature(x, k=self.k)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x2, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv5(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)      # (batch_size, 64*3, num_points)

        x = self.conv6(x)                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)

        l = l.view(batch_size, -1, 1)           # (batch_size, num_categories, 1)
        l = self.conv7(l)                       # (batch_size, num_categories, 1) -> (batch_size, 64, 1)

        x = torch.cat((x, l), dim=1)            # (batch_size, emb_dims+64, 1)
        x = x.repeat(1, 1, num_points)          # (batch_size, emb_dims+64, num_points)

        x = torch.cat((x, x1, x2, x3), dim=1)   # (batch_size, emb_dims+64+64*3, num_points)

        x = self.conv8(x)                       # (batch_size, emb_dims+64+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)

        return x.permute(0, 2, 1)               # -> (batch_size, num_points, seg_num_all)

# Training (self-supervised: predict each point's original voxel id from voxel-shuffled clouds)

In [None]:
# Pre-extracted ShapeNet point clouds (xyz only).
dataset = 'H:/shapenet57448xyzonly.npz'
with np.load(dataset) as sn:
    data = sn['data']

# NOTE: despite its name, `num_points` is the number of *models* (first axis of `data`).
num_points = data.shape[0]

# One random category per model, one-hot encoded as a (num_models, 16) float matrix.
rands = np.random.randint(0, 16, size=num_points)
labels = np.eye(16)[rands]

# Per-point voxel ids stored under np.savez's default key 'arr_0'
# (presumably produced by generate_self_supervised for the clouds in `data` — TODO confirm).
voxids_path = 'H:/voxids.npz'
with np.load(voxids_path) as lbs:
    voxids = lbs['arr_0']

In [None]:
class ShapeNetDataset(Dataset):
    """Self-supervised voxel-shuffle dataset built from in-memory point clouds.

    For each model, ``num_points`` points are sampled without replacement. Their voxels are
    permuted with ``shuffle_voxels`` on a 3x3x3 grid, and each point's original voxel id becomes
    its target. Items are ``((points, category), voxel_ids)``:
    - ``points`` has shape (3, num_points);
    - ``category`` is the model's label row as floats;
    - ``voxel_ids`` is a long tensor of shape (num_points,).
    """

    def __init__(self, models, labels, voxids, num_points=1024):
        n_models = len(models)
        data = torch.empty((n_models, num_points, 3))
        vids = torch.empty((n_models, num_points), dtype=torch.long)

        for i, pointcloud in enumerate(models):
            # Random subset of this cloud's points. Uses the cloud's own size instead of a
            # hard-coded 2048; the RNG draw is identical when the cloud has 2048 points.
            inds = np.random.permutation(len(pointcloud))[:num_points]
            vs = voxids[i, inds]

            shuffled = shuffle_voxels(pointcloud[inds], vs, 3)

            data[i] = torch.as_tensor(shuffled)
            vids[i] = torch.as_tensor(vs)

        self.size   = len(labels)
        self.data   = data.permute(0, 2, 1)   # (models, num_points, 3) -> (models, 3, num_points)
        self.labels = torch.as_tensor(labels).float()
        self.vids = vids
        print('Done')

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError(f'{self.__class__.__name__} index out of range.')

        return (self.data[index], self.labels[index]), self.vids[index]

    def __len__(self):
        return self.size

In [None]:
# NOTE(review): torch.load unpickles arbitrary Python objects; only load this file from a
# trusted source. Unpickling presumably needs the dataset class defined above
# (ShapeNetDataset) to exist in this kernel. On PyTorch >= 2.6, whose default is
# weights_only=True, loading a pickled DataLoader needs weights_only=False.
trainloader_path = 'H:/trainloader_partseg.pkl'
with open(trainloader_path, 'rb') as loader_file:
    train_loader = torch.load(loader_file)

In [None]:
def loss_function(pred, gold, smoothing=True, eps=0.2):
    ''' Cross-entropy between class scores and integer targets, with optional label smoothing.

    Args:
        pred: scores whose last axis indexes classes (e.g. (batch_size, num_points, n_class)).
            Flattened to (-1, n_class).
        gold: long class targets, one per row of the flattened ``pred``.
        smoothing: if True, use label smoothing; otherwise plain ``F.cross_entropy``.
        eps: probability mass taken from the true class and spread evenly over the other
            ``n_class - 1`` classes (default 0.2, the previously hard-coded value).

    Returns:
        Scalar mean loss over all rows.
    '''
    pred = pred.contiguous().view(-1, pred.shape[-1])
    gold = gold.contiguous().flatten()

    if smoothing:
        n_class = pred.size(1)

        one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
        # With a single class there is nowhere to move mass; eps / 0 would yield NaN.
        if n_class > 1:
            one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
        log_prb = F.log_softmax(pred, dim=1)
        loss = -(one_hot * log_prb).sum(dim=1).mean()
    else:
        loss = F.cross_entropy(pred, gold, reduction='mean')

    return loss

In [None]:
# Base learning rate shared by Adam and the decay schedule.
LR_INITIAL = 0.0001

model = DGCNN_partseg()
opt = torch.optim.Adam(model.parameters(), lr=LR_INITIAL)

def calculate_gamma_initial_lr(epoch):
    """Multiplicative factor applied to the initial learning rate at ``epoch``.

    Halves every epoch. Once LR_INITIAL * 0.5**epoch would drop below 1e-5, a fixed factor of
    0.1 is returned instead.
    """
    factor = 0.5 ** epoch
    if LR_INITIAL * factor < 0.00001:
        return 0.1
    return factor

scheduler = LambdaLR(calculate_gamma_initial_lr)
lrprint = LRPrinter(opt)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
trial = Trial(model, opt, loss_function,
              callbacks=[scheduler, lrprint, save_state],
              metrics=['loss']).to(device)

trial.with_generators(train_loader)
trial.run(epochs=200)