In [None]:
import torch
try:
    import torchbearer
except:
    !pip install torchbearer
    import torchbearer
from torchbearer import Trial
from torchbearer import callbacks
from torch import nn
import torch.nn.functional as F
from torchbearer.callbacks.torch_scheduler import CosineAnnealingLR
from torchbearer.callbacks import Callback
import numpy as np
import h5py
from torch.utils.data import Dataset, DataLoader
import os
from time import time

In [None]:
# Report the CUDA toolkit version torch was built with, then the torch and torchbearer versions.
for version_string in (torch.version.cuda, torch.__version__, torchbearer.__version__):
    print(version_string)

In [None]:
@callbacks.on_start_epoch
def lr_print(state):
    """At the start of every epoch, print the learning rate of the optimizer's first param group."""
    first_group = state[torchbearer.OPTIMIZER].param_groups[0]
    print(first_group['lr'])
        
@callbacks.on_end_epoch
def save_state(state):
    """At the end of every epoch, write the model's state_dict to /kaggle/working/trained_<epoch>.ckpt."""
    epoch = state[torchbearer.EPOCH]
    ckpt_path = f'/kaggle/working/trained_{epoch}.ckpt'
    weights = state[torchbearer.MODEL].state_dict()
    torch.save(weights, ckpt_path)
    print('Model params saved, epoch ' + str(epoch))
    
@callbacks.on_end_epoch
def miou_val(state):
    """At the end of every epoch, print part IoU / category mIoU on the validation loader.

    The model is put in eval mode for the measurement, so dropout is off and
    BatchNorm uses, rather than updates, its running statistics. Gradients are
    disabled. The previous train/eval mode is restored afterwards, even on
    error, so the next training epoch is unaffected.

    NOTE(review): reads a global ``valloader``; the cell lines that build it are
    commented out, so this callback raises NameError unless they are re-enabled.
    """
    model = state[torchbearer.MODEL]
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # calc_miou returns a (per-part IoU, per-category mIoU) tuple.
            result = calc_miou(model, valloader)
    finally:
        model.train(was_training)
    print(result)

In [None]:
from sklearn.metrics import jaccard_score

def calc_miou(model, dataloader):
    """Compute part-segmentation IoU over every point yielded by ``dataloader``.

    Each batch must be ``((points, category_one_hot), part_ids)``, and
    ``model(points, category_one_hot)`` must return per-point class scores whose
    last axis is the class axis, e.g. ``(batch, n_points, n_classes)``.

    Inputs are sent to the device of the model's parameters, and gradients are
    disabled here. Switching the model to eval mode is left to the caller.

    Returns:
        iou: numpy array of per-part IoU in percent, one entry per part label
            ``0 .. sum(n_parts)-1`` (50 labels).
        miou: list with one value per object category: the mean IoU of that
            category's consecutive block of part labels.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    # Number of part labels for each of the 16 object categories, in label order.
    # Consecutive slices of the per-part IoU vector are averaged per category.
    n_parts = [4, 2, 2, 4, 4, 3, 3, 2, 4, 2, 6, 2, 3, 3, 3, 3]
    n_labels = sum(n_parts)

    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    pred_chunks, true_chunks = [], []
    with torch.no_grad():
        for (data, labels), pids in dataloader:
            scores = model(data.to(device), labels.to(device))
            # (batch, n_points, n_classes) -> (batch*n_points, n_classes) -> class id per point
            y_pred = torch.flatten(scores.contiguous(), 0, 1).argmax(dim=1)
            y_true = torch.flatten(pids.contiguous(), 0, 1)
            pred_chunks.append(y_pred.cpu().numpy())
            true_chunks.append(y_true.cpu().numpy())

    if not pred_chunks:
        raise ValueError('calc_miou: dataloader yielded no batches.')

    # Concatenating per-batch results works for any batch size and point count,
    # unlike fixed write offsets into a preallocated buffer.
    y_preds = np.concatenate(pred_chunks)
    y_trues = np.concatenate(true_chunks)

    print(np.unique(y_preds, return_counts=True))
    print(np.unique(y_trues, return_counts=True))

    # `labels=` pins the output to exactly n_labels entries in label order. Without it,
    # sklearn reports only labels present in y_trues or y_preds, which would shift the
    # per-category slices below. A label absent from both gets IoU 0 (sklearn's
    # zero_division default, with a warning).
    iou = jaccard_score(y_trues, y_preds, labels=np.arange(n_labels), average=None) * 100

    miou = []
    start = 0
    for n in n_parts:
        miou.append(np.mean(iou[start:start + n]))
        start += n
    return iou, miou

In [None]:
class ShapeNetPartDataset(Dataset):
    """In-memory ShapeNet part-segmentation dataset built from HDF5 shards.

    Every ``.h5`` file found (recursively) under ``dataset_path`` whose name
    contains ``partition_name`` is read. Each file must provide ``data``,
    ``label`` and ``pid`` datasets, which are concatenated across files.

    Items are ``((points, category_one_hot), part_ids)``. ``points`` comes from
    ``data`` with its last two axes swapped
    ([models, points, dims] -> [models, dims, points]).

    Args:
        dataset_path: directory searched recursively for ``.h5`` files.
        partition_name: substring a filename must contain (e.g. ``'train'``);
            ``None`` loads every ``.h5`` file.
        num_categories: width of the category one-hot vector. It is fixed rather
            than inferred, so every partition has the same width even if the
            highest category id is absent. 16 matches the 16 input channels of
            ``DGCNN_partseg.conv7``.

    Raises:
        FileNotFoundError: if no matching ``.h5`` file is found.
    """

    def __init__(self, dataset_path, partition_name=None, num_categories=16):
        data_chunks, label_chunks, pid_chunks = [], [], []

        for root, _dirs, files in os.walk(dataset_path):
            # Sorted so the concatenation order (and thus item indices) is reproducible.
            for filename in sorted(files):
                if not filename.endswith('.h5'):
                    continue
                if partition_name is not None and partition_name not in filename:
                    continue
                print('Loading '+filename)
                # Join with `root`: os.walk also yields files from subdirectories.
                # The context manager closes the HDF5 handle once its arrays are read.
                with h5py.File(os.path.join(root, filename), 'r') as f:
                    data_chunks.append(torch.as_tensor(f['data'][()], dtype=torch.float32))
                    label_chunks.append(torch.as_tensor(f['label'][()], dtype=torch.long).squeeze(-1))
                    pid_chunks.append(torch.as_tensor(f['pid'][()], dtype=torch.long))

        if not data_chunks:
            raise FileNotFoundError(
                f'No .h5 files matching {partition_name!r} found under {dataset_path!r}.')

        n_files = len(data_chunks)
        # Concatenate once instead of growing tensors inside the loop.
        data = torch.cat(data_chunks)
        labels = torch.cat(label_chunks)
        pids = torch.cat(pid_chunks)

        self.size = len(labels)
        self.data = data.permute(0, 2, 1)  # [models, points, dims] -> [models, dims, points]
        self.labels = F.one_hot(labels, num_classes=num_categories).float()
        self.pids = pids
        strprt = '' if partition_name is None else partition_name + ' '
        print(str(self.data.shape[0]) + ' ' + strprt + 'models loaded from ' + str(n_files) + ' files.')

    def __getitem__(self, index):
        """Return ``((points, category_one_hot), part_ids)`` for model ``index``."""
        if index >= len(self):
            raise IndexError(f'{self.__class__.__name__} index out of range.')

        return (self.data[index], self.labels[index]), self.pids[index]

    def __len__(self):
        """Number of models loaded."""
        return self.size

In [None]:
dataset_path = '/kaggle/input/shapenet-part-seg-hdf5-data/shapenet_part_seg_hdf5_data/shapenet_part_seg_hdf5_data'
BATCH_SIZE = 32

trainset = ShapeNetPartDataset(dataset_path, 'train')
#valset = ShapeNetPartDataset(dataset_path, 'val')
testset = ShapeNetPartDataset(dataset_path, 'test')
# Training: shuffle each epoch and drop the final partial batch.
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
#valloader = DataLoader(valset, batch_size=BATCH_SIZE, shuffle=False)
# Evaluation: the metric does not depend on order, so keep iteration deterministic.
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False)

# Sanity-check the shapes of one test batch.
(data, labels), pids = next(iter(testloader))
print(data.shape)
print(labels.shape)
print(pids.shape)

In [None]:
def knn(x, k):
    """Indices of the ``k`` nearest points (squared Euclidean) for every point.

    ``x`` is ``(batch_size, num_dims, num_points)``. The result is
    ``(batch_size, num_points, k)``, ordered nearest first; each point is
    normally its own first neighbour (distance 0).
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norms = torch.sum(x**2, dim=1, keepdim=True)
    # Negated squared distance -|a|^2 + 2<a,b> - |b|^2, so topk (largest) picks the closest points.
    neg_sq_dist = 2*gram - sq_norms - sq_norms.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest

def get_graph_feature(x, k=20, idx=None, dim9=False):
    """Build DGCNN edge features between every point and its k nearest neighbours.

    Args:
        x: point features, viewed as ``(batch_size, num_dims, num_points)``
            (``num_points`` is read from ``x.size(2)``).
        k: neighbours per point.
        idx: optional precomputed neighbour indices ``(batch_size, num_points, k)``;
            computed with ``knn`` when ``None``.
        dim9: if True, neighbours are searched using only channels ``6:`` of ``x``.

    Returns:
        ``(batch_size, 2*num_dims, num_points, k)`` tensor. For each neighbour it
        holds ``neighbour - centre`` followed by ``centre`` along the channel axis.
    """
    batch_size = x.size(0)
    num_points = x.size(2)

    x = x.view(batch_size, -1, num_points)

    if idx is None:
        idx = knn(x[:, 6:], k=k) if dim9 else knn(x, k=k)   # (batch_size, num_points, k)

    # Offset each sample's indices so they address rows of the flattened
    # (batch_size*num_points, num_dims) tensor below. The offsets are created on
    # x's device, so this works on CPU and on any GPU rather than only cuda:0.
    idx_base = torch.arange(0, batch_size, device=x.device).view(-1, 1, 1)*num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()

    x = x.transpose(2, 1).contiguous()   # (batch_size, num_points, num_dims)
    feature = x.view(batch_size*num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)
    feature = torch.cat((feature-x, x), dim=3).permute(0, 3, 1, 2).contiguous()
    return feature      # (batch_size, 2*num_dims, num_points, k)

class Transform_Net(nn.Module):
    """Predicts a per-sample 3x3 matrix from k-NN edge features.

    Input: (batch_size, 6, num_points, k), i.e. 3*2 channels of edge features on 3-D points.
    Output: (batch_size, 3, 3); the final layer is initialised so the first
    output is the identity matrix.
    """
    def __init__(self):
        super(Transform_Net, self).__init__()
        self.k = 3  # not read by forward()

        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm1d(1024)  # captured by conv3 below; the attribute is rebound later

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv1d(128, 1024, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        # NOTE: rebinding self.bn3 replaces the registered 'bn3' submodule with this
        # 512-feature BatchNorm (applied after linear1). conv3 keeps its own reference
        # to the 1024-feature one (state_dict keys 'conv3.1.*'). Do not rename or
        # reorder: the pretrained checkpoint is keyed by these attribute names.
        self.bn3 = nn.BatchNorm1d(512)
        self.linear2 = nn.Linear(512, 256, bias=False)
        self.bn4 = nn.BatchNorm1d(256)

        # Zero weights plus identity-shaped bias: the initial output is the identity transform.
        self.transform = nn.Linear(256, 3*3)
        torch.nn.init.constant_(self.transform.weight, 0)
        torch.nn.init.eye_(self.transform.bias.view(3, 3))

    def forward(self, x):
        """Map edge features (batch_size, 6, num_points, k) to a (batch_size, 3, 3) matrix."""
        batch_size = x.size(0)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 128, num_points, k)
        x = x.max(dim=-1, keepdim=False)[0]     # (batch_size, 128, num_points, k) -> (batch_size, 128, num_points)

        x = self.conv3(x)                       # (batch_size, 128, num_points) -> (batch_size, 1024, num_points)
        x = x.max(dim=-1, keepdim=False)[0]     # (batch_size, 1024, num_points) -> (batch_size, 1024)

        x = F.leaky_relu(self.bn3(self.linear1(x)), negative_slope=0.2)     # (batch_size, 1024) -> (batch_size, 512)
        x = F.leaky_relu(self.bn4(self.linear2(x)), negative_slope=0.2)     # (batch_size, 512) -> (batch_size, 256)

        x = self.transform(x)                   # (batch_size, 256) -> (batch_size, 3*3)
        x = x.view(batch_size, 3, 3)            # (batch_size, 3*3) -> (batch_size, 3, 3)

        return x

In [None]:
class DGCNN_partseg(nn.Module):
    """DGCNN part-segmentation network.

    Args:
        k: neighbours per point in every edge-convolution.
        emb_dims: width of the global feature produced by conv6.
        dropout: dropout probability before conv9 and conv10.
        seg_num_all: number of per-point output classes (conv11 width).

    forward(x, l):
        x: points (batch_size, 3, num_points).
        l: category one-hot; viewed as (batch_size, -1, 1) and fed to conv7,
           which expects 16 channels.
        returns per-point scores (batch_size, num_points, seg_num_all).

    Submodule names and creation order are kept stable because pretrained
    state_dicts are loaded into this class.
    """
    def __init__(self, k=20, emb_dims=1024, dropout=0.5, seg_num_all = 27):
        super(DGCNN_partseg, self).__init__()
        self.k = k
        self.transform_net = Transform_Net()
        
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(64)
        self.bn4 = nn.BatchNorm2d(64)
        self.bn5 = nn.BatchNorm2d(64)
        self.bn6 = nn.BatchNorm1d(emb_dims)
        self.bn7 = nn.BatchNorm1d(64)
        self.bn8 = nn.BatchNorm1d(256)
        self.bn9 = nn.BatchNorm1d(256)
        self.bn10 = nn.BatchNorm1d(128)

        self.conv1 = nn.Sequential(nn.Conv2d(6, 64, kernel_size=1, bias=False),
                                   self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn2,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv3 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn3,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv4 = nn.Sequential(nn.Conv2d(64, 64, kernel_size=1, bias=False),
                                   self.bn4,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv5 = nn.Sequential(nn.Conv2d(64*2, 64, kernel_size=1, bias=False),
                                   self.bn5,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv6 = nn.Sequential(nn.Conv1d(192, emb_dims, kernel_size=1, bias=False),
                                   self.bn6,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv7 = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False),
                                   self.bn7,
                                   nn.LeakyReLU(negative_slope=0.2))
        # Input = global feature (emb_dims) + category feature (64) + x1, x2, x3 (3*64).
        # Derived from emb_dims rather than hardcoded 1280, which was only right for emb_dims=1024.
        self.conv8 = nn.Sequential(nn.Conv1d(emb_dims + 64 + 64*3, 256, kernel_size=1, bias=False),
                                   self.bn8,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp1 = nn.Dropout(p=dropout)
        self.conv9 = nn.Sequential(nn.Conv1d(256, 256, kernel_size=1, bias=False),
                                   self.bn9,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.dp2 = nn.Dropout(p=dropout)
        self.conv10 = nn.Sequential(nn.Conv1d(256, 128, kernel_size=1, bias=False),
                                   self.bn10,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.conv11 = nn.Conv1d(128, seg_num_all, kernel_size=1, bias=False)
        

    def forward(self, x, l):
        batch_size = x.size(0)
        num_points = x.size(2)
        x0 = get_graph_feature(x, k=self.k)     # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        t = self.transform_net(x0)              # (batch_size, 3, 3)
        x = x.transpose(2, 1)                   # (batch_size, 3, num_points) -> (batch_size, num_points, 3)
        x = torch.bmm(x, t)                     # (batch_size, num_points, 3) * (batch_size, 3, 3) -> (batch_size, num_points, 3)
        x = x.transpose(2, 1)                   # (batch_size, num_points, 3) -> (batch_size, 3, num_points)
        
        x = get_graph_feature(x, k=self.k)      # (batch_size, 3, num_points) -> (batch_size, 3*2, num_points, k)
        x = self.conv1(x)                       # (batch_size, 3*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv2(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x1 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = get_graph_feature(x1, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv3(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x = self.conv4(x)                       # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points, k)
        x2 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)
        
        x = get_graph_feature(x2, k=self.k)     # (batch_size, 64, num_points) -> (batch_size, 64*2, num_points, k)
        x = self.conv5(x)                       # (batch_size, 64*2, num_points, k) -> (batch_size, 64, num_points, k)
        x3 = x.max(dim=-1, keepdim=False)[0]    # (batch_size, 64, num_points, k) -> (batch_size, 64, num_points)

        x = torch.cat((x1, x2, x3), dim=1)      # (batch_size, 64*3, num_points)
        
        x = self.conv6(x)                       # (batch_size, 64*3, num_points) -> (batch_size, emb_dims, num_points)
        x = x.max(dim=-1, keepdim=True)[0]      # (batch_size, emb_dims, num_points) -> (batch_size, emb_dims, 1)
        
        l = l.view(batch_size, -1, 1)           # (batch_size, 16, 1) category one-hot
        l = self.conv7(l)                       # (batch_size, 16, 1) -> (batch_size, 64, 1)
        
        x = torch.cat((x, l), dim=1)            # (batch_size, emb_dims+64, 1)
        x = x.repeat(1, 1, num_points)          # (batch_size, emb_dims+64, num_points)
        
        x = torch.cat((x, x1, x2, x3), dim=1)   # (batch_size, emb_dims+64+64*3, num_points)
        
        x = self.conv8(x)                       # (batch_size, emb_dims+64+64*3, num_points) -> (batch_size, 256, num_points)
        x = self.dp1(x)
        x = self.conv9(x)                       # (batch_size, 256, num_points) -> (batch_size, 256, num_points)
        x = self.dp2(x)
        x = self.conv10(x)                      # (batch_size, 256, num_points) -> (batch_size, 128, num_points)
        x = self.conv11(x)                      # (batch_size, 128, num_points) -> (batch_size, seg_num_all, num_points)
        
        return x.permute(0,2,1)                 # (batch_size, num_points, seg_num_all)

In [None]:
def loss_fn(ypred, ytrue, smoothing=True, eps=0.2):
    ''' Calculate cross entropy loss, apply label smoothing if needed.

    Args:
        ypred: logits whose last axis is the class axis, e.g.
            (batch_size, num_points, n_class); all leading axes are flattened.
        ytrue: integer class ids with the same leading shape as ``ypred``.
        smoothing: if True, the target puts ``1 - eps`` on the true class and
            spreads ``eps`` evenly over the other ``n_class - 1`` classes.
        eps: smoothing mass (default 0.2, the previously hardcoded value).
            Ignored when ``smoothing`` is False.

    Returns:
        Scalar loss averaged over all flattened positions.
    '''
    logits = ypred.contiguous().view(-1, ypred.shape[-1])
    targets = ytrue.contiguous().flatten()

    if not smoothing:
        return F.cross_entropy(logits, targets, reduction='mean')

    n_class = logits.size(1)
    one_hot = torch.zeros_like(logits).scatter(1, targets.view(-1, 1), 1)
    soft_targets = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
    log_prb = F.log_softmax(logits, dim=1)
    return -(soft_targets * log_prb).sum(dim=1).mean()

In [None]:
EPOCHS = 200
LR = 1e-3
# Cosine annealing moves the LR from the optimizer's initial LR down to eta_min.
# eta_min must be below LR: if it equals LR (1e-3), the schedule is constant and
# the scheduler has no effect.
ETA_MIN = 1e-5
N_PART_CLASSES = 50  # one output channel per part label

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DGCNN_partseg()
# map_location='cpu' lets a GPU-saved checkpoint load on any machine; Trial.to(device) moves the model afterwards.
model.load_state_dict(torch.load('/kaggle/input/shapenet-part-seg-hdf5-data/pretrained_partseg_199.weights',
                                 map_location='cpu'))
# Replace the pretrained output head with a freshly initialised 50-way part classifier.
model.conv11 = nn.Conv1d(128, N_PART_CLASSES, kernel_size=1, bias=False)
opt = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)

# To resume from a checkpoint written by this cell:
#ckpt = torch.load('/kaggle/input/shapenet-part-seg-hdf5-data/trained_partseg_ca_99.ckpt')
#model.load_state_dict(ckpt['model'])
#opt.load_state_dict(ckpt['opt'])

scheduler = CosineAnnealingLR(EPOCHS, eta_min=ETA_MIN)

trial = Trial(model, opt, loss_fn, callbacks=[scheduler, lr_print], metrics=['loss']).to(device)

trial.with_generators(trainloader)
trial.run(epochs=EPOCHS)

torch.save({'model': model.state_dict(),
            'opt': opt.state_dict()},
           '/kaggle/working/trained_partseg_ca_199.ckpt')

In [None]:
# Final evaluation on the test split: eval mode, no autograd.
model.eval()
with torch.no_grad():
    iou, miou = calc_miou(model, testloader)
for result in (iou, miou):
    print(result)