In [1]:
import torch
import torch.nn as nn
from torch_geometric.data import Data
from torch.utils.data import Dataset, DataLoader

import MinkowskiEngine as ME
import MinkowskiEngine.MinkowskiFunctional as MF


import open3d as o3d
import os
import torch
import einops
from tqdm import tqdm

In [2]:
def category_indices(base_path, train_folder, gt_folder, train_prefix=None, gt_prefix=None, format=".txt"):
    '''
        Group the sample indices of a dataset by ground-truth category and
        write the result to ``base_path/indices.txt``.

        This function assumes the following file structure:

        base_path
        -train_folder
            -train_prefix + {i} + format
        -gt_folder
            -gt_prefix + {i} + format

        with samples numbered consecutively starting at 1, and the following
        file contents:

        GT:
            the first line holds the category of the sample
        Train:
            N lines containing 3 comma-separated floats corresponding to point
            coordinates (these files are only counted here, never parsed)

        Output file:
            one line per category, in order of first appearance, holding the
            comma-separated indices of the samples in that category.

        Args:
            base_path: root folder containing both sub-folders.
            train_folder: name of the folder with the point-cloud files.
            gt_folder: name of the folder with the ground-truth files.
            train_prefix: file-name prefix of point-cloud files
                (defaults to ``train_folder``).
            gt_prefix: file-name prefix of ground-truth files
                (defaults to ``gt_folder``).
            format: file extension, including the dot.

        Raises:
            ValueError: if the first line of a ground-truth file is empty.
    '''
    # Prefixes default to the folder names (e.g. "pointCloud/pointCloud1.txt").
    train_prefix = train_prefix or train_folder
    gt_prefix = gt_prefix or gt_folder

    train_path = os.path.join(base_path, train_folder)
    gt_path = os.path.join(base_path, gt_folder)

    # Count only real sample files, so stray directory entries (hidden files,
    # notebook checkpoints, ...) do not inflate the number of samples.
    n_samples = sum(
        1
        for name in os.listdir(train_path)
        if name.startswith(train_prefix) and name.endswith(format)
    )

    categories = {}
    for i in tqdm(range(1, n_samples + 1)):
        gt_name = os.path.join(gt_path, gt_prefix + str(i) + format)
        with open(gt_name) as gt_file:
            # Use the whole first line rather than its first character, so
            # that multi-digit category ids ("1" vs "12") stay distinct.
            cat = gt_file.readline().strip()

        if not cat:
            raise ValueError(f"No category found on the first line of {gt_name}")

        categories.setdefault(cat, []).append(i)

    save_path = os.path.join(base_path, "indices.txt")
    with open(save_path, "w") as out:
        for members in categories.values():
            out.write(",".join(map(str, members)) + "\n")

In [3]:
# Build training/indices.txt: one line of sample indices per ground-truth category.
# (The train/validation split itself is created later by SHREC2022Dataset.)
base = "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset/training/"
category_indices(
    base_path=base,
    train_folder="pointCloud",
    gt_folder="GTpointCloud",
)

100%|██████████| 46000/46000 [00:00<00:00, 87914.71it/s]


In [4]:
class SHREC2022Dataset(torch.utils.data.Dataset):
    """
    Point-cloud classification dataset over a SHREC2022-style folder layout.

    Files read or written by this class (relative to ``path``):

        training/indices.txt      per-category sample indices, one line per
                                  category (read when a split is created)
        training/split_info.txt   "<valid_split>,<n_valid>,<n_train>"
        training/train_split.txt  one training sample index per line
        training/valid_split.txt  one validation sample index per line
        <training|test>/pointCloud/pointCloud{i}.txt
        <training|test>/GTpointCloud/GTpointCloud{i}.txt

    Each item is a dict with keys ``"x"`` (points scaled into the unit
    sphere), ``"y"`` (the ``'data'`` entry of the parsed label) and
    ``"norm_factor"`` (the scale that was divided out).

    NOTE(review): ``parse_point_cloud`` and ``parse_label`` are called by
    ``__getitem__`` but are not defined in this class as shown; they must be
    provided (e.g. by a subclass) before items can be fetched.
    NOTE(review): with ``train=False`` samples are still looked up through
    ``train_split.txt`` (unchanged from the previous behaviour) -- confirm the
    intended indexing of the test set.
    """

    def __init__(self, path, train=True, valid=False, valid_split=0.2):
        """
        Args:
            path: dataset root containing the ``training`` / ``test`` folders.
            train: use the training folder (True) or the test folder (False).
            valid: serve the validation part of the split; ignored when
                ``train`` is False.
            valid_split: fraction of every category held out for validation.
        """
        self.path = os.path.join(path, "training" if train else "test")
        self.pc_prefix = "pointCloud"
        self.gt_prefix = "GTpointCloud"
        self.format = ".txt"
        self.valid = valid if train else False
        self.size = 0 if train else len(os.listdir(self.path))
        # Lazily filled cache of the sample indices served by this instance.
        self._indices = None

        split_info_file = os.path.join(path, "training/split_info.txt")
        self.t_savefile = os.path.join(path, "training/train_split.txt")
        self.v_savefile = os.path.join(path, "training/valid_split.txt")

        if not train:
            # The train/validation split is irrelevant for the test set; in
            # particular it must not overwrite the size computed above.
            return

        # Check if an existing train-validation split matches the one given.
        if os.path.exists(split_info_file):
            with open(split_info_file) as F:
                v, vsize, tsize = map(float, F.readline().split(','))

            if v == valid_split:
                print("Specified split already exists. Using the existing one.")
                self.size = int(vsize if self.valid else tsize)
                return

        print("Creating a new train-validation split.")
        self._create_split(path, split_info_file, valid_split)

    def _create_split(self, path, split_info_file, valid_split):
        """Create a per-category random split and persist it to the split files."""
        import random

        # Parse the category index completely before touching the output
        # files, so a missing or malformed index cannot leave truncated splits.
        with open(os.path.join(path, "training/indices.txt")) as F:
            categories = [
                list(map(int, line.split(","))) for line in F if line.strip()
            ]

        t_indices = []
        v_indices = []
        for members in categories:
            random.shuffle(members)
            # Sized per category, so categories of different sizes are each
            # split with the requested ratio.
            train_sz = int(len(members) * (1 - valid_split))
            t_indices.extend(members[:train_sz])
            v_indices.extend(members[train_sz:])

        with open(self.t_savefile, "w") as T:
            T.write("\n".join(map(str, t_indices)))
        with open(self.v_savefile, "w") as V:
            V.write("\n".join(map(str, v_indices)))
        # Written last: its presence marks the split files as complete.
        with open(split_info_file, "w") as I:
            I.write(str(valid_split)+','+str(len(v_indices))+','+str(len(t_indices)))

        self._indices = v_indices if self.valid else t_indices
        self.size = len(self._indices)

    def _sample_ids(self):
        """Return the sample indices of this split, reading the split file once."""
        ids = getattr(self, "_indices", None)
        if ids is None:
            with open(self.v_savefile if self.valid else self.t_savefile, "r") as F:
                ids = [int(line) for line in F if line.strip()]
            self._indices = ids
        return ids

    def __getitem__(self, index):
        """Load, parse and normalize the ``index``-th sample of this split."""
        # Map the position inside the split to the on-disk sample number.
        index = self._sample_ids()[index]

        # assembling the file name for the data and labels
        pc_name = os.path.join(self.path, self.pc_prefix, self.pc_prefix + str(index) + self.format)
        gt_name = os.path.join(self.path, self.gt_prefix, self.gt_prefix + str(index) + self.format)

        # parsing the point cloud and its label
        pcloud = self.parse_point_cloud(pc_name)
        label = self.parse_label(gt_name)

        data = {"x": pcloud, "y": label['data']}

        # `transform` is a method, not a list of callables: call it directly.
        return self.transform(data)

    def __len__(self):
        """Number of samples served by this instance."""
        return self.size

    def transform(self, data):
        """Apply the per-sample preprocessing (unit-sphere normalization)."""
        return self.unit_sphere_normalize(data)

    def unit_sphere_normalize(self, x):
        """
        Scale ``x["x"]`` in place so that its farthest point has norm 1 and
        store the scale under ``x["norm_factor"]``.
        """
        max_norm = (x["x"]*x["x"]).sum(-1).max().sqrt()
        # A cloud collapsed onto the origin is left as is instead of becoming NaN.
        if max_norm > 0:
            x["x"] /= max_norm

        x["norm_factor"] = max_norm

        return x



def minkowski_collate(list_data):
    """
    Collate a list of dataset samples into one MinkowskiEngine batch.

    Every sample's point array ``d['x']`` is handed to
    ``ME.utils.sparse_collate`` twice: once as coordinates and once as
    features. The label of a sample is the first entry of ``d['y']``, kept as
    a length-1 tensor.

    Returns:
        dict with the keys "coordinates", "features" and "labels" holding the
        three values returned by ``sparse_collate``.
    """
    point_coords = []
    point_feats = []
    targets = []
    for sample in list_data:
        point_coords.append(sample['x'])
        point_feats.append(sample['x'])
        targets.append(sample['y'][0].unsqueeze(0))

    coordinates, features, labels = ME.utils.sparse_collate(
        point_coords, point_feats, targets, dtype=torch.float32
    )

    return dict(coordinates=coordinates, features=features, labels=labels)

In [5]:
# Training and validation views over the same train/validation split
path = "/home/ioannis/Desktop/programming/data/SHREC/SHREC2022/dataset"
t_dataset = SHREC2022Dataset(path=path, train=True, valid=False, valid_split=0.2)
v_dataset = SHREC2022Dataset(path=path, train=True, valid=True, valid_split=0.2)

# Fetch one sample from each split as a smoke test
sample1, sample2 = t_dataset[0], v_dataset[0]
# show_point_cloud_o3d(sample1['x'])
# show_point_cloud_o3d(sample2['x'])
print(len(t_dataset), len(v_dataset), sep="\n")

Specified split already exists. Using the existing one.
Specified split already exists. Using the existing one.
36800
9200


In [6]:
batch_size = 16

# Training batches are reshuffled every epoch; validation keeps a fixed order.
# Both loaders assemble their batches with minkowski_collate.
train_loader = DataLoader(
    dataset=t_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
    collate_fn=minkowski_collate,
)
eval_loader = DataLoader(
    dataset=v_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
    collate_fn=minkowski_collate,
)

In [7]:
# Benchmark: wall-clock seconds needed to iterate once over the training
# loader. Only data loading and collation are measured; no model is involved.
from tqdm import tqdm
import time

# Start the clock right before the first batch is requested.
t1 = time.time()
for batch in tqdm(train_loader):
    pass
# Elapsed time in seconds for the full pass.
print(time.time() - t1)

100%|██████████| 2300/2300 [01:27<00:00, 26.17it/s]

87.89322829246521





# Network

## PointNet

In [8]:
class MinkowskiPointNet(ME.MinkowskiNetwork):
    """
    PointNet-style classifier built from MinkowskiEngine per-point layers.

    Five Linear + BatchNorm + ReLU blocks (in_channel -> 64 -> 64 -> 64 ->
    128 -> embedding_channel) are followed by a global max pooling and a
    classification head (embedding_channel -> 512 -> dropout -> out_channel).

    Args:
        in_channel: number of feature channels per input point. The first
            layer now follows this value (it used to be hard-coded to 3).
        out_channel: number of output classes.
        embedding_channel: width of the pooled global feature.
        dimension: spatial dimension passed to ``ME.MinkowskiNetwork``.
    """

    def __init__(self, in_channel, out_channel, embedding_channel=1024, dimension=3):
        ME.MinkowskiNetwork.__init__(self, dimension)

        # Per-point feature extractor. Attribute names and creation order are
        # kept so that existing checkpoints and seeded initialisation match.
        self.conv1 = self._mlp_block(in_channel, 64)
        self.conv2 = self._mlp_block(64, 64)
        self.conv3 = self._mlp_block(64, 64)
        self.conv4 = self._mlp_block(64, 128)
        self.conv5 = self._mlp_block(128, embedding_channel)
        self.max_pool = ME.MinkowskiGlobalMaxPooling()

        # Classification head on the pooled feature.
        self.linear1 = self._mlp_block(embedding_channel, 512)
        self.dp1 = ME.MinkowskiDropout()
        self.linear2 = ME.MinkowskiLinear(512, out_channel, bias=True)

    @staticmethod
    def _mlp_block(in_features, out_features):
        """Linear (no bias) + BatchNorm + ReLU, the block repeated throughout the network."""
        return nn.Sequential(
            ME.MinkowskiLinear(in_features, out_features, bias=False),
            ME.MinkowskiBatchNorm(out_features),
            ME.MinkowskiReLU(),
        )

    def forward(self, x: ME.TensorField):
        """Return the raw class scores (the ``.F`` feature matrix of the last layer)."""
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.max_pool(x)
        x = self.linear1(x)
        x = self.dp1(x)
        return self.linear2(x).F


## Minkowski Encoder

In [9]:
class MinkowskiFCNN(ME.MinkowskiNetwork):
    """
    Sparse convolutional classifier built from MinkowskiEngine modules.

    A per-point MLP is followed by four convolution blocks whose pooled
    outputs are sliced back onto the input points, concatenated, and reduced
    by three more strided convolution blocks. Global max and average pooling
    of that result feed the classification head.
    """

    def __init__(
        self,
        in_channel,
        out_channel,
        embedding_channel=1024,
        channels=(32, 48, 64, 96, 128),
        D=3,
    ):
        """
        Args:
            in_channel: number of feature channels per input point.
            out_channel: number of output classes.
            embedding_channel: channel width produced by ``conv5``.
            channels: five widths, for ``mlp1`` and ``conv1`` .. ``conv4``.
            D: spatial dimension passed to ``ME.MinkowskiNetwork``.
        """
        ME.MinkowskiNetwork.__init__(self, D)

        # Layers are built with a fixed 3-wide convolution kernel.
        self.network_initialization(
            in_channel,
            out_channel,
            channels=channels,
            embedding_channel=embedding_channel,
            kernel_size=3,
            D=D,
        )
        self.weight_initialization()

    def get_mlp_block(self, in_channel, out_channel):
        """Return a Linear (no bias) + BatchNorm + LeakyReLU block."""
        return nn.Sequential(
            ME.MinkowskiLinear(in_channel, out_channel, bias=False),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def get_conv_block(self, in_channel, out_channel, kernel_size, stride):
        """Return a Convolution + BatchNorm + LeakyReLU block in ``self.D`` dimensions."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_channel,
                out_channel,
                kernel_size=kernel_size,
                stride=stride,
                dimension=self.D,
            ),
            ME.MinkowskiBatchNorm(out_channel),
            ME.MinkowskiLeakyReLU(),
        )

    def network_initialization(
        self,
        in_channel,
        out_channel,
        channels,
        embedding_channel,
        kernel_size,
        D=3,
    ):
        """
        Create all layers of the network.

        ``D`` is only used for ``self.pool``; the convolution blocks take
        their dimension from ``self.D``.
        """
        # Per-point MLP followed by one stride-1 and three stride-2 conv blocks.
        self.mlp1 = self.get_mlp_block(in_channel, channels[0])
        self.conv1 = self.get_conv_block(
            channels[0],
            channels[1],
            kernel_size=kernel_size,
            stride=1,
        )
        self.conv2 = self.get_conv_block(
            channels[1],
            channels[2],
            kernel_size=kernel_size,
            stride=2,
        )

        self.conv3 = self.get_conv_block(
            channels[2],
            channels[3],
            kernel_size=kernel_size,
            stride=2,
        )

        self.conv4 = self.get_conv_block(
            channels[3],
            channels[4],
            kernel_size=kernel_size,
            stride=2,
        )
        # conv5 consumes the concatenation of the four pooled feature maps
        # (hence the summed input width) and widens it to embedding_channel.
        self.conv5 = nn.Sequential(
            self.get_conv_block(
                channels[1] + channels[2] + channels[3] + channels[4],
                embedding_channel // 4,
                kernel_size=3,
                stride=2,
            ),
            self.get_conv_block(
                embedding_channel // 4,
                embedding_channel // 2,
                kernel_size=3,
                stride=2,
            ),
            self.get_conv_block(
                embedding_channel // 2,
                embedding_channel,
                kernel_size=3,
                stride=2,
            ),
        )

        # Local pooling applied after each of conv1 .. conv4 in forward().
        self.pool = ME.MinkowskiMaxPooling(kernel_size=3, stride=2, dimension=D)

        self.global_max_pool = ME.MinkowskiGlobalMaxPooling()
        self.global_avg_pool = ME.MinkowskiGlobalAvgPooling()

        # Head input is twice embedding_channel: global max and average
        # pooled features are concatenated in forward().
        self.final = nn.Sequential(
            self.get_mlp_block(embedding_channel * 2, 512),
            ME.MinkowskiDropout(),
            self.get_mlp_block(512, 512),
            ME.MinkowskiLinear(512, out_channel, bias=True),
        )

        # No, Dropout, last 256 linear, AVG_POOLING 92%

    def weight_initialization(self):
        """Kaiming-normal init for convolution kernels; batch-norm weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(m.kernel, mode="fan_out", nonlinearity="relu")

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def forward(self, x: ME.TensorField):
        """Return the raw class scores (the ``.F`` feature matrix of the head output)."""
        x = self.mlp1(x)
        # Switch from the per-point field to its sparse-tensor form for the
        # convolutions (MinkowskiEngine `TensorField.sparse()`).
        y = x.sparse()

        # Four conv + pool stages; the pooled output of each stage is kept.
        y = self.conv1(y)
        y1 = self.pool(y)

        y = self.conv2(y1)
        y2 = self.pool(y)

        y = self.conv3(y2)
        y3 = self.pool(y)

        y = self.conv4(y3)
        y4 = self.pool(y)

        # Bring every stage back onto the points of the field `x` so the
        # four feature sets can be concatenated.
        x1 = y1.slice(x)
        x2 = y2.slice(x)
        x3 = y3.slice(x)
        x4 = y4.slice(x)

        x = ME.cat(x1, x2, x3, x4)

        # Reduce the concatenated features, then pool them globally two ways.
        y = self.conv5(x.sparse())
        x1 = self.global_max_pool(y)
        x2 = self.global_avg_pool(y)

        return self.final(ME.cat(x1, x2)).F

In [10]:
# Prefer the first CUDA device; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

# Despite its name, `minkpointnet` holds a MinkowskiFCNN (3 input channels, 5 classes).
minkpointnet = MinkowskiFCNN(in_channel=3, out_channel=5)
minkpointnet = minkpointnet.to(device)

cuda:0


# Training Loop

In [11]:
def create_input_batch(batch, device="cuda", quantization_size=0.05):
    """
    Wrap a collated batch into an ``ME.TensorField``.

    All coordinate columns except the first are divided by
    ``quantization_size``; column 0 is passed through unscaled (presumably
    the batch index written by the collate function -- confirm against
    ``ME.utils.sparse_collate``).

    Args:
        batch: dict with the keys "coordinates" and "features".
        device: device on which the tensor field is created.
        quantization_size: positive size of a quantization cell.

    Returns:
        ME.TensorField built from the scaled coordinates and the features.

    Raises:
        ValueError: if ``quantization_size`` is not positive.
    """
    if quantization_size <= 0:
        raise ValueError("quantization_size must be positive")

    # Scale a copy: mutating batch["coordinates"] in place would rescale the
    # same batch again on every further call with it.
    coordinates = batch["coordinates"].clone()
    coordinates[:, 1:] = coordinates[:, 1:] / quantization_size

    return ME.TensorField(
        coordinates=coordinates,
        features=batch["features"],
        device=device
    )

In [15]:
import time

cls_loss = torch.nn.CrossEntropyLoss()

# NOTE(review): param_loss is created here but never used in this cell.
param_loss = torch.nn.MSELoss()

num_epochs = 20
eval_every = 1  # run a validation pass every `eval_every` epochs
optimizer = torch.optim.Adam(minkpointnet.parameters(), lr=1e-3)


for i in range(num_epochs):

    # ---- training pass ----
    minkpointnet.train()
    m_loss = 0.0
    n_correct = 0
    n_seen = 0
    for batch in tqdm(train_loader):

        optimizer.zero_grad()
        labels = batch["labels"].long().to(device)

        # Local name on purpose: the notebook-level `batch_size` used to build
        # the loaders must not be overwritten by the size of the last batch.
        n_batch = labels.shape[0]

        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        pred = minkpointnet(minknet_input)

        loss = cls_loss(pred, labels)
        loss.backward()
        optimizer.step()

        # cls_loss is a per-batch mean; weight it by the batch size so the
        # epoch figures are per-sample means even if the last batch is smaller.
        m_loss += loss.item() * n_batch
        n_correct += (torch.max(pred, dim=-1).indices == labels).sum().item()
        n_seen += n_batch

    m_loss /= max(n_seen, 1)
    acc = n_correct / max(n_seen, 1)

    print(f"epoch {i} - loss: {m_loss}")
    print(f"train accuracy {acc}")

    # ---- validation pass ----
    if (i+1) % eval_every == 0:
        m_loss = 0.0
        n_correct = 0
        n_seen = 0
        minkpointnet.eval()
        with torch.no_grad():
            for ebatch in eval_loader:

                labels = ebatch["labels"].long().to(device)
                n_batch = labels.shape[0]

                minknet_input = create_input_batch(
                    ebatch,
                    device=device,
                    quantization_size=0.05
                )

                pred = minkpointnet(minknet_input)

                loss = cls_loss(pred, labels)

                m_loss += loss.item() * n_batch
                n_correct += (torch.max(pred, dim=-1).indices == labels).sum().item()
                n_seen += n_batch

            m_loss /= max(n_seen, 1)
            acc = n_correct / max(n_seen, 1)
            print(f"validation_loss: {m_loss}")
            print(f"validation_acc : {acc}")
        # Back to training mode for the next epoch.
        minkpointnet.train()
    

100%|██████████| 2300/2300 [02:23<00:00, 16.06it/s]

epoch 0 - loss: 0.4326362850063521
train accuracy 0.8411141633987427





validation_loss: 0.534489544911877
validation_acc : 0.800000011920929


100%|██████████| 2300/2300 [02:26<00:00, 15.71it/s]

epoch 1 - loss: 0.3273414047167677
train accuracy 0.8791847825050354





validation_loss: 0.6103396369776
validation_acc : 0.783152163028717


100%|██████████| 2300/2300 [02:22<00:00, 16.16it/s]

epoch 2 - loss: 0.24028613207983257
train accuracy 0.9142934679985046





validation_loss: 0.6943070894609327
validation_acc : 0.7771739363670349


100%|██████████| 2300/2300 [02:21<00:00, 16.25it/s]

epoch 3 - loss: 0.1814176823557152
train accuracy 0.9369293451309204





validation_loss: 0.7306370152549251
validation_acc : 0.7794565558433533


100%|██████████| 2300/2300 [02:22<00:00, 16.15it/s]

epoch 4 - loss: 0.1409721631496011
train accuracy 0.950951099395752





validation_loss: 0.7489643728368632
validation_acc : 0.7871739268302917


100%|██████████| 2300/2300 [02:22<00:00, 16.19it/s]

epoch 5 - loss: 0.11780726776245738
train accuracy 0.9588587284088135





validation_loss: 0.8387530057608028
validation_acc : 0.7830435037612915


100%|██████████| 2300/2300 [02:30<00:00, 15.26it/s]

epoch 6 - loss: 0.09607708493594344
train accuracy 0.9666032791137695





validation_loss: 0.8981141074746847
validation_acc : 0.7873913049697876


100%|██████████| 2300/2300 [02:22<00:00, 16.09it/s]

epoch 7 - loss: 0.08784098507841523
train accuracy 0.968668520450592





validation_loss: 0.8608694858460323
validation_acc : 0.7872826457023621


100%|██████████| 2300/2300 [02:21<00:00, 16.29it/s]

epoch 8 - loss: 0.0763776102642473
train accuracy 0.9742391705513





validation_loss: 0.9793378312728853
validation_acc : 0.7733696103096008


100%|██████████| 2300/2300 [02:20<00:00, 16.33it/s]

epoch 9 - loss: 0.07015811979113926
train accuracy 0.9764402508735657





validation_loss: 0.9559305762685836
validation_acc : 0.786195695400238


100%|██████████| 2300/2300 [02:21<00:00, 16.28it/s]

epoch 10 - loss: 0.060614798235107195
train accuracy 0.9796739220619202





validation_loss: 1.002106832701551
validation_acc : 0.7865217328071594


100%|██████████| 2300/2300 [02:21<00:00, 16.31it/s]

epoch 11 - loss: 0.0595244886492646
train accuracy 0.9802989363670349





validation_loss: 0.9706727385196997
validation_acc : 0.7850000262260437


100%|██████████| 2300/2300 [02:21<00:00, 16.27it/s]

epoch 12 - loss: 0.05313013940632632
train accuracy 0.9819021821022034





validation_loss: 1.0979773423083774
validation_acc : 0.7754347920417786


100%|██████████| 2300/2300 [02:21<00:00, 16.26it/s]

epoch 13 - loss: 0.0461304036746309
train accuracy 0.984646737575531





validation_loss: 1.0906997536892153
validation_acc : 0.7866304516792297


100%|██████████| 2300/2300 [02:21<00:00, 16.24it/s]

epoch 14 - loss: 0.044197986306453066
train accuracy 0.9851087331771851





validation_loss: 1.1075005747884026
validation_acc : 0.7836956977844238


100%|██████████| 2300/2300 [02:21<00:00, 16.21it/s]

epoch 15 - loss: 0.04246408314610048
train accuracy 0.9853532910346985





validation_loss: 1.1423766938583297
validation_acc : 0.7781521677970886


100%|██████████| 2300/2300 [02:21<00:00, 16.24it/s]

epoch 16 - loss: 0.042558459441396065
train accuracy 0.9859783053398132





validation_loss: 1.1205272277582274
validation_acc : 0.7783696055412292


100%|██████████| 2300/2300 [02:22<00:00, 16.13it/s]

epoch 17 - loss: 0.040303393152497174
train accuracy 0.9866847991943359





validation_loss: 1.1245689939019148
validation_acc : 0.7731521725654602


100%|██████████| 2300/2300 [02:26<00:00, 15.67it/s]

epoch 18 - loss: 0.03355003540808154
train accuracy 0.9887500405311584





validation_loss: 1.1468779277158003
validation_acc : 0.782608687877655


100%|██████████| 2300/2300 [02:25<00:00, 15.81it/s]

epoch 19 - loss: 0.0353776682241859
train accuracy 0.9886141419410706





validation_loss: 1.1089505644129467
validation_acc : 0.7823913097381592


In [None]:
## Creating a confusion matrix for the data
# Batch size 1 so that every prediction can be tallied individually. The
# pass runs over the validation split (`dataset` was never defined in this
# notebook); a fixed order keeps it reproducible.
conf_loader = DataLoader(v_dataset, batch_size=1, shuffle=False, collate_fn = minkowski_collate, num_workers=8)

# One row/column per class; matches out_channel of the network.
num_classes = 5
preds = torch.zeros(num_classes, num_classes)

# Start both accumulators from zero instead of relying on values left over
# from the training cell.
acc = 0
m_loss = 0

with torch.no_grad():
    minkpointnet.eval()

    for batch in tqdm(conf_loader):

        labels = batch["labels"].long().to(device)

        minknet_input = create_input_batch(
            batch,
            device=device,
            quantization_size=0.05
        )

        pred = minkpointnet(minknet_input)

        pred_ind = pred.max(dim=-1).indices.cpu()

        # rows: true class, columns: predicted class
        preds[labels.item(), pred_ind.item()] += 1

        loss = cls_loss(pred, labels)

        m_loss += loss.item()

        # Each batch holds a single sample, so this adds 1.0 for a hit and 0.0 for a miss.
        acc += (pred_ind == labels.cpu()).sum().item() / labels.shape[0]
        


In [None]:
import matplotlib.pyplot as plt

def confusion_matrix(data, xlabels: list, ylabels: list):
    """
    Plot one confusion matrix, or a batch of them, with labelled ticks.

    Args:
        data: tensor of shape (N, N) or (B, N, N).
        xlabels: N tick labels for the x axis. For batched input this may
            instead be a list of B label lists (one list/tuple of N labels
            per matrix); a single list of N labels is shared by all matrices.
        ylabels: same as ``xlabels``, for the y axis.

    Batched input is laid out on a grid with 3 rows if B is a multiple of 3,
    else 2 rows if B is even, else a single row.

    Raises:
        AssertionError: if ``data`` is not square or the labels do not match it.
    """

    assert data.dim() in [2,3], "The input must be a square matrix or a batch of square matrices"

    def _per_matrix_labels(labels, count, size):
        # Returns one label list per matrix, sharing a flat list of `size`
        # labels between all `count` matrices when necessary.
        is_nested = len(labels) == count and all(
            isinstance(sub, (list, tuple)) and len(sub) == size for sub in labels
        )
        if is_nested:
            return labels
        assert len(labels) == size, "Expected one label per class or one label list per matrix"
        return [labels for _ in range(count)]

    if data.dim() == 3:

        B, N, M = data.shape
        assert N == M
        xlabels = _per_matrix_labels(xlabels, B, N)
        ylabels = _per_matrix_labels(ylabels, B, N)

        rows, cols = 1, B
        if B % 3 == 0:
            rows, cols = 3, B//3
        elif B % 2 == 0:
            rows, cols = 2, B//2

        ticks = list(range(N))

        # squeeze=False keeps `axes` two-dimensional even for a single row or column.
        fig, axes = plt.subplots(rows, cols, squeeze=False)
        for i in range(B):
            r, c = divmod(i, cols)
            ax = axes[r, c]
            ax.matshow(data[i])
            # Label this subplot explicitly; plt.xticks/yticks would only
            # affect whichever axes happens to be current.
            ax.set_xticks(ticks)
            ax.set_xticklabels(xlabels[i])
            ax.set_yticks(ticks)
            ax.set_yticklabels(ylabels[i])

    else:

        N, M = data.shape
        assert N == M
        assert len(xlabels) == N and len(ylabels) == N

        ticks = list(range(N))

        plt.matshow(data)
        plt.xticks(ticks = ticks, labels = xlabels)
        plt.yticks(ticks = ticks, labels = ylabels)

    plt.show()

In [None]:
# Mean accuracy of the confusion-matrix pass: `acc` holds the sum of the
# per-batch accuracies, so dividing by the number of batches gives the mean.
acc / len(conf_loader)