In [2]:
import os
from torch.utils.data import Dataset
import numpy as np
import h5py
# from google.colab import drive
import torch
import torch.utils.data
import argparse
# drive.mount('/content/drive')

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Semantic category names; a category's position in this list is its integer label.
classes = [
    'ceiling', 'floor', 'wall', 'beam', 'column', 'window', 'door',
    'table', 'chair', 'sofa', 'bookcase', 'board', 'clutter',
]
# Reverse lookup: category name -> integer label.
class2label = dict(zip(classes, range(len(classes))))

In [4]:
def getDataFiles(list_filename):
    """Return the lines of a text file with trailing whitespace stripped.

    Input:
        list_filename: path of a text file holding one entry per line
    Return:
        list of str, one per line, in file order
    """
    # The context manager closes the handle even if reading fails part-way.
    with open(list_filename) as list_file:
        return [line.rstrip() for line in list_file]

def load_h5(h5_filename):
    """Read the ``data`` and ``label`` datasets of an HDF5 file into memory.

    Input:
        h5_filename: path of the HDF5 file
    Return:
        (data, label): the full contents of the two datasets
    """
    # Open read-only and release the file handle as soon as both datasets
    # have been copied into memory by the ``[:]`` slices.
    with h5py.File(h5_filename, 'r') as h5_file:
        data = h5_file['data'][:]
        label = h5_file['label'][:]
    return (data, label)

def loadDataFile(filename):
    """Load an HDF5 file given by a path relative to the current directory.

    Prefixes ``filename`` with ``./`` and delegates to ``load_h5``.
    """
    relative_path = './' + filename
    return load_h5(relative_path)

In [5]:
def recognize_all_data(test_area = 5):
    """Load every listed HDF5 file and split the blocks into train and test sets.

    All files named in ``all_files.txt`` are loaded and concatenated along the
    first axis. Entry ``i`` of ``room_filelist.txt`` names the room of block
    ``i``; blocks whose room name contains ``'Area_<test_area>'`` form the test
    set and every other block forms the training set.

    Input:
        test_area: number of the area held out for testing
    Return:
        train_data, train_label, test_data, test_label
    """
    ALL_FILES = getDataFiles('./indoor3d_sem_seg_hdf5_data/all_files.txt')
    # Read the room list through a context manager so the handle is closed.
    with open('./indoor3d_sem_seg_hdf5_data/room_filelist.txt') as room_file:
        room_filelist = [line.rstrip() for line in room_file]

    data_batch_list = []
    label_batch_list = []
    for h5_filename in ALL_FILES:
        data_batch, label_batch = loadDataFile(h5_filename)
        data_batch_list.append(data_batch)
        label_batch_list.append(label_batch)
    data_batches = np.concatenate(data_batch_list, 0)
    label_batches = np.concatenate(label_batch_list, 0)

    # Substring match of the area tag against each room name decides the split.
    area_tag = 'Area_' + str(test_area)
    train_idxs = []
    test_idxs = []
    for i, room_name in enumerate(room_filelist):
        if area_tag in room_name:
            test_idxs.append(i)
        else:
            train_idxs.append(i)

    train_data = data_batches[train_idxs, ...]
    train_label = label_batches[train_idxs]
    test_data = data_batches[test_idxs, ...]
    test_label = label_batches[test_idxs]
    print('train_data',train_data.shape,'train_label' ,train_label.shape)
    print('test_data',test_data.shape,'test_label', test_label.shape)
    return train_data,train_label,test_data,test_label


In [6]:
class S3DISDataLoader(Dataset):
    """Map-style dataset that pairs each entry of ``data`` with the matching entry of ``labels``.

    Both containers are stored as given and indexed with the same index.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        # The dataset length is the length of the data container.
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, index):
        sample = self.data[index]
        sample_label = self.labels[index]
        return sample, sample_label


In [7]:
# Hold out the blocks of Area 5 for testing; all remaining blocks are used for training.
train_data, train_label, test_data, test_label = recognize_all_data(test_area=5)

train_data (16733, 4096, 9) train_label (16733, 4096)
test_data (6852, 4096, 9) test_label (6852, 4096)


In [8]:
# Training split: shuffled mini-batches of 24 blocks, loaded by 4 worker processes.
dataset = S3DISDataLoader(train_data, train_label)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=24,
    shuffle=True,
    num_workers=4,
)

# Held-out split: shuffled mini-batches of 8 blocks, loaded by 4 worker processes.
test_dataset = S3DISDataLoader(test_data, test_label)
testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,
)

In [9]:
import torch.nn as nn
import torch.nn.functional as F
from time import time

def timeit(tag, t):
    """Print the seconds elapsed since timestamp ``t`` under ``tag`` and return a fresh timestamp."""
    elapsed = time() - t
    print("{}: {}s".format(tag, elapsed))
    return time()

In [10]:
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    The distance is expanded as
        |s - d|^2 = sum(s^2, dim=-1) + sum(d^2, dim=-1) - 2 * s^T d
    so the cross term is computed by a single batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    cross_term = torch.matmul(src, dst.permute(0, 2, 1))    # [B, N, M]
    src_sq_norm = torch.sum(src ** 2, -1).view(B, N, 1)     # broadcast over M
    dst_sq_norm = torch.sum(dst ** 2, -1).view(B, 1, M)     # broadcast over N
    return -2 * cross_term + src_sq_norm + dst_sq_norm

In [11]:
def index_points(points, idx):
    """
    Gather rows of ``points`` per batch element according to ``idx``.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (any number of trailing index axes is accepted)
    Return:
        new_points: indexed points data, [B, S, C] (``idx.shape + [C]`` in general)
    """
    batch_size = points.shape[0]
    trailing = list(idx.shape[1:])
    # A batch-id tensor shaped like ``idx`` pairs every index with its batch row.
    column_shape = [idx.shape[0]] + [1] * len(trailing)
    tile_shape = [1] + trailing
    batch_ids = torch.arange(batch_size, dtype=torch.long).to(points.device)
    batch_ids = batch_ids.view(column_shape).repeat(tile_shape)
    return points[batch_ids, idx, :]

In [12]:
def farthest_point_sample(xyz, npoint):
    """
    Iteratively pick ``npoint`` indices per batch element, each time choosing the
    point farthest from everything already selected. The first index is random.

    Input:
        xyz: pointcloud data, [B, N, C]
        npoint: number of samples
    Return:
        centroids: sampled pointcloud index, [B, npoint]
    """
    device = xyz.device
    B, N, C = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=device)
    # Squared distance from every point to its nearest selected centroid so far.
    distance = torch.ones(B, N, device=device) * 1e10
    # Drawn on the CPU and then moved, so the random stream used is the same on any device.
    farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device)
    batch_indices = torch.arange(B, dtype=torch.long, device=device)
    for i in range(npoint):
        centroids[:, i] = farthest
        # Use the actual coordinate width C rather than a hard-coded 3.
        centroid = xyz[batch_indices, farthest, :].view(B, 1, C)
        dist = torch.sum((xyz - centroid) ** 2, -1)
        mask = dist < distance
        distance[mask] = dist[mask]
        # The next centroid is the point with the largest remaining distance.
        farthest = torch.max(distance, -1)[1]
    return centroids

In [13]:
def query_ball_point(radius, nsample, xyz, new_xyz):
    """
    For every query point, collect up to ``nsample`` indices of points lying within ``radius``.

    Indices are returned in ascending order. When fewer than ``nsample`` points
    are inside the ball, the remaining slots repeat the first index found.

    Input:
        radius: local region radius
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    device = xyz.device
    B, N, C = xyz.shape
    _, S, _ = new_xyz.shape
    candidates = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1])
    outside_ball = square_distance(new_xyz, xyz) > radius ** 2
    # N is one past the last valid index, so rejected entries sort to the end.
    candidates[outside_ball] = N
    nearest = candidates.sort(dim=-1)[0][:, :, :nsample]
    fallback = nearest[:, :, 0].view(B, S, 1).repeat([1, 1, nsample])
    padding = nearest == N
    nearest[padding] = fallback[padding]
    return nearest

In [14]:
def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False):
    """
    Sample centroids by farthest point sampling and gather a ball-query neighbourhood around each.

    Input:
        npoint: number of centroids to sample
        radius: ball-query radius
        nsample: number of neighbours gathered per centroid
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
        returnfps: when True, also return the grouped coordinates and the centroid features
    Return:
        new_xyz: sampled points position data, [B, npoint, C]
        new_points: centroid-relative coordinates joined with features, [B, npoint, nsample, C+D]
                    (only the relative coordinates when ``points`` is None)
        grouped_xyz: (returnfps only) absolute neighbour coordinates, [B, npoint, nsample, C]
        fps_points: (returnfps only) centroid coordinates joined with centroid features,
                    [B, npoint, C+D] (``new_xyz`` itself when ``points`` is None)
    """
    B, _, C = xyz.shape
    centroid_idx = farthest_point_sample(xyz, npoint)  # [B, npoint]
    new_xyz = index_points(xyz, centroid_idx)
    neighbour_idx = query_ball_point(radius, nsample, xyz, new_xyz)
    grouped_xyz = index_points(xyz, neighbour_idx)  # [B, npoint, nsample, C]
    # Express each neighbour's coordinates relative to its centroid.
    local_xyz = grouped_xyz - new_xyz.view(B, npoint, 1, C)

    if points is None:
        new_points, fps_points = local_xyz, new_xyz
    else:
        neighbour_feats = index_points(points, neighbour_idx)
        centroid_feats = index_points(points, centroid_idx)
        fps_points = torch.cat([new_xyz, centroid_feats], dim=-1)
        new_points = torch.cat([local_xyz, neighbour_feats], dim=-1)  # [B, npoint, nsample, C+D]

    if not returnfps:
        return new_xyz, new_points
    return new_xyz, new_points, grouped_xyz, fps_points

In [15]:
def sample_and_group_all(xyz, points):
    """
    Treat the whole cloud as one group centred on an all-zero point.

    Input:
        xyz: input points position data, [B, N, C]
        points: input points data, [B, N, D], or None
    Return:
        new_xyz: all-zero centre position, [B, 1, C]
        new_points: coordinates joined with features, [B, 1, N, C+D]
                    (coordinates only when ``points`` is None)
    """
    B, N, C = xyz.shape
    new_xyz = torch.zeros(B, 1, C).to(xyz.device)
    grouped_xyz = xyz.view(B, 1, N, C)
    if points is None:
        return new_xyz, grouped_xyz
    grouped_feats = points.view(B, 1, N, -1)
    return new_xyz, torch.cat([grouped_xyz, grouped_feats], dim=-1)

class GraphAttention(nn.Module):
    """Attention-weighted pooling of neighbour features around each centre point.

    Attention scores are produced per feature channel from the concatenated
    position and feature offsets between a centre and each of its neighbours.
    """

    def __init__(self,all_channel,feature_dim,dropout,alpha):
        super().__init__()
        self.alpha = alpha
        self.dropout = dropout
        # Learnable projection from the (position offset, feature offset) vector to per-channel scores.
        self.a = nn.Parameter(torch.zeros(size=(all_channel, feature_dim)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        self.leakyrelu = nn.LeakyReLU(self.alpha)

    def forward(self, center_xyz, center_feature, grouped_xyz, grouped_feature):
        '''
        Input:
            center_xyz: sampled points position data [B, npoint, C]
            center_feature: centered point feature [B, npoint, D]
            grouped_xyz: group xyz data [B, npoint, nsample, C]
            grouped_feature: sampled points feature [B, npoint, nsample, D]
        Return:
            graph_pooling: results of graph pooling [B, npoint, D]
        '''
        B, npoint, C = center_xyz.size()
        _, _, nsample, D = grouped_feature.size()
        # Repeat every centre along the neighbour axis so offsets can be taken element-wise.
        centre_pos = center_xyz.view(B, npoint, 1, C).expand(B, npoint, nsample, C)
        centre_feat = center_feature.view(B, npoint, 1, D).expand(B, npoint, nsample, D)
        offsets = torch.cat(
            [centre_pos - grouped_xyz, centre_feat - grouped_feature], dim=-1
        )  # [B, npoint, nsample, C+D]
        scores = self.leakyrelu(torch.matmul(offsets, self.a))  # [B, npoint, nsample, D]
        # Normalise over the neighbour axis, independently for each channel.
        weights = F.softmax(scores, dim=2)
        weights = F.dropout(weights, self.dropout, training=self.training)
        return torch.sum(torch.mul(weights, grouped_feature), dim=2)  # [B, npoint, D]


class GraphAttentionConvLayer(nn.Module):
    """Set-abstraction layer: sample centroids, group neighbours, apply a shared
    1x1-convolution MLP, then pool each neighbourhood with graph attention.
    """

    def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all,droupout=0.6,alpha=0.2):
        """
        Input:
            npoint: number of centroids sampled from the input cloud
            radius: ball-query radius used for grouping
            nsample: number of neighbours gathered per centroid
            in_channel: channel count entering the MLP (coordinates + features)
            mlp: output channel count of each successive 1x1 convolution
            group_all: stored flag; the forward pass does not support True
            droupout: dropout probability applied to the attention weights
            alpha: negative slope of the attention LeakyReLU
        """
        super(GraphAttentionConvLayer, self).__init__()
        self.npoint = npoint
        self.radius = radius
        self.nsample = nsample
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        self.droupout = droupout
        self.alpha = alpha
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm2d(out_channel))
            last_channel = out_channel
        self.group_all = group_all
        # The attention input is the 3-d position offset joined with the feature offset.
        self.GAT = GraphAttention(3+last_channel,last_channel,self.droupout,self.alpha)

    def forward(self, xyz, points):
        """
        Input:
            xyz: input points position data, [B, C, N]
            points: input points data, [B, D, N]
        Return:
            new_xyz: sampled points position data, [B, C, S]
            new_points_concat: sample points feature data, [B, D', S]
        Raises:
            NotImplementedError: if the layer was built with ``group_all=True``.
        """
        if self.group_all:
            # Global grouping yields neither grouped coordinates nor centroid
            # features, both of which the attention pooling below requires.
            raise NotImplementedError(
                'GraphAttentionConvLayer does not support group_all=True')

        xyz = xyz.permute(0, 2, 1)
        if points is not None:
            points = points.permute(0, 2, 1)

        new_xyz, new_points, grouped_xyz, fps_points = sample_and_group(
            self.npoint, self.radius, self.nsample, xyz, points, True)
        # new_xyz: sampled points position data, [B, npoint, C]
        # new_points: sampled points data, [B, npoint, nsample, C+D]
        # fps_points: centroid data, [B, npoint, C+D]
        new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample, npoint]
        fps_points = fps_points.unsqueeze(3).permute(0, 2, 3, 1) # [B, C+D, 1, npoint]
        # Centroids and neighbours pass through the same conv / batch-norm modules.
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            fps_points = F.relu(bn(conv(fps_points)))
            new_points =  F.relu(bn(conv(new_points)))
        # new_points: [B, F, nsample, npoint]
        # fps_points: [B, F, 1, npoint]
        # Squeeze only the singleton neighbour axis: a bare squeeze() would also
        # drop a batch or centroid axis of size 1 and break the permute below.
        new_points = self.GAT(center_xyz=new_xyz,
                              center_feature=fps_points.squeeze(2).permute(0,2,1),
                              grouped_xyz=grouped_xyz,
                              grouped_feature=new_points.permute(0,3,2,1))
        new_xyz = new_xyz.permute(0, 2, 1)
        new_points = new_points.permute(0, 2, 1)
        return new_xyz, new_points

class PointNetFeaturePropagation(nn.Module):
    """Upsample features from a sparse point set onto a denser one and refine them
    with a shared 1x1-convolution MLP.
    """

    def __init__(self, in_channel, mlp):
        """
        Input:
            in_channel: channel count entering the MLP (skip features + interpolated features)
            mlp: output channel count of each successive 1x1 convolution
        """
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        last_channel = in_channel
        for out_channel in mlp:
            self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1))
            self.mlp_bns.append(nn.BatchNorm1d(out_channel))
            last_channel = out_channel

    def forward(self, xyz1, xyz2, points1, points2):
        """
        Input:
            xyz1: input points position data, [B, C, N]
            xyz2: sampled input points position data, [B, C, S]
            points1: input points data, [B, D, N], or None
            points2: input points data, [B, D, S]
        Return:
            new_points: upsampled points data, [B, D', N]
        """
        xyz1 = xyz1.permute(0, 2, 1)
        xyz2 = xyz2.permute(0, 2, 1)

        points2 = points2.permute(0, 2, 1)
        B, N, C = xyz1.shape
        _, S, _ = xyz2.shape

        if S == 1:
            # A single source point: every target point receives its features.
            interpolated_points = points2.repeat(1, N, 1)
        else:
            # Inverse-squared-distance weighting over the nearest source points.
            # At most three neighbours are used, fewer when only two exist.
            k = min(3, S)
            dists = square_distance(xyz1, xyz2)
            dists, idx = dists.sort(dim=-1)
            dists, idx = dists[:, :, :k], idx[:, :, :k]  # [B, N, k]
            dists[dists < 1e-10] = 1e-10  # avoid division by zero for coincident points
            weight = 1.0 / dists  # [B, N, k]
            weight = weight / torch.sum(weight, dim=-1).view(B, N, 1)  # [B, N, k]
            interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, k, 1), dim=2)

        if points1 is not None:
            # Join the skip-connection features with the interpolated ones.
            points1 = points1.permute(0, 2, 1)
            new_points = torch.cat([points1, interpolated_points], dim=-1)
        else:
            new_points = interpolated_points

        new_points = new_points.permute(0, 2, 1)
        for conv, bn in zip(self.mlp_convs, self.mlp_bns):
            new_points =  F.relu(bn(conv(new_points)))
        return new_points

In [16]:
class GACNet(nn.Module):
    """Encoder-decoder segmentation network: four graph-attention set-abstraction
    layers followed by four feature-propagation layers and a per-point classifier.
    """

    def __init__(self, num_classes,droupout=0,alpha=0.2):
        """
        Input:
            num_classes: number of per-point output classes
            droupout: dropout probability used in the attention layers and the classifier head
            alpha: negative slope of the attention LeakyReLU
        """
        super(GACNet, self).__init__()
        # GraphAttentionConvLayer: npoint, radius, nsample, in_channel, mlp, group_all,droupout,alpha
        # Each in_channel is the previous layer's feature width plus 3 coordinates.
        self.sa1 = GraphAttentionConvLayer(1024, 0.1, 32, 6 + 3, [32, 32, 64], False, droupout,alpha)
        self.sa2 = GraphAttentionConvLayer(256, 0.2, 32, 64 + 3, [64, 64, 128], False, droupout,alpha)
        self.sa3 = GraphAttentionConvLayer(64, 0.4, 32, 128 + 3, [128, 128, 256], False, droupout,alpha)
        self.sa4 = GraphAttentionConvLayer(16, 0.8, 32, 256 + 3, [256, 256, 512], False, droupout,alpha)
        # PointNetFeaturePropagation: in_channel, mlp
        # Each in_channel is the skip feature width plus the width being upsampled.
        self.fp4 = PointNetFeaturePropagation(768, [256, 256])       # 256 (l3) + 512 (l4)
        self.fp3 = PointNetFeaturePropagation(384, [256, 256])       # 128 (l2) + 256
        self.fp2 = PointNetFeaturePropagation(320, [256, 128])       # 64 (l1) + 256
        self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128])  # no skip features + 128

        self.conv1 = nn.Conv1d(128, 128, 1)
        self.bn1 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(droupout)
        self.conv2 = nn.Conv1d(128, num_classes, 1)

    def forward(self, xyz, point):
        """
        Input:
            xyz: point coordinates, [B, 3, N]
            point: per-point features, [B, 6, N]
        Return:
            per-point log-probabilities, [B, N, num_classes]
        """
        # Encoder: progressively fewer points with wider features.
        l1_xyz, l1_points = self.sa1(xyz, point)
        l2_xyz, l2_points = self.sa2(l1_xyz, l1_points)
        l3_xyz, l3_points = self.sa3(l2_xyz, l2_points)
        l4_xyz, l4_points = self.sa4(l3_xyz, l3_points)

        # Decoder: propagate features back up to the full-resolution cloud.
        l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points)
        l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points)
        l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points)
        l0_points = self.fp1(xyz, l1_xyz, None, l1_points)

        x = self.drop1(F.relu(self.bn1(self.conv1(l0_points))))
        x = self.conv2(x)
        x = F.log_softmax(x, dim=1)
        x = x.permute(0, 2, 1)
        return x

In [17]:
from collections import defaultdict

# Number of semantic categories the network predicts.
num_classes = 13


def blue(x):
    """Wrap the string ``x`` in the ANSI escape codes '\\033[94m' ... '\\033[0m'."""
    return '\033[94m' + x + '\033[0m'


# Positional arguments: number of classes, dropout probability, LeakyReLU slope.
model = GACNet(num_classes, 0, 0.2)
# Epoch index at which the training loop starts.
init_epoch = 0

def adjust_learning_rate(optimizer, step, base_lr=0.01, decay=0.3, decay_steps=20000):
    """Apply a staircase learning-rate schedule to every parameter group.

    The learning rate is ``base_lr * decay ** (step // decay_steps)``; with the
    defaults it starts at 0.01 and is multiplied by 0.3 every 20000 steps.

    Input:
        optimizer: object exposing ``param_groups``, a list of dicts with an ``'lr'`` entry
        step: number of optimisation steps taken so far
        base_lr: learning rate used before the first decay
        decay: multiplicative factor applied at every decay boundary
        decay_steps: number of steps between two decays
    Return:
        lr: the learning rate that was written to every parameter group
    """
    lr = base_lr * (decay ** (step // decay_steps))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

# Adam with weight decay; the learning rate set here is overwritten after
# every training step by adjust_learning_rate.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, betas=(0.9, 0.999),
                             eps=1e-08, weight_decay=1e-4)

# model.cuda()

# Training state shared with the training loop: per-step metric history,
# best evaluation scores seen so far, and the global step counter.
history = defaultdict(list)
best_acc, best_meaniou, step = 0, 0, 0

In [18]:
# *_*coding:utf-8 *_*
#@title Utils here
import os
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.autograd import Variable
from tqdm import tqdm
from collections import defaultdict
import datetime
import pandas as pd
import torch.nn.functional as F
def to_categorical(y, num_classes):
    """One-hot encode a tensor of integer class labels.

    Rows of a ``num_classes`` x ``num_classes`` identity matrix are selected by
    the labels; the result is moved to the GPU when ``y`` lives there.
    """
    identity = torch.eye(num_classes)
    label_array = y.cpu().data.numpy()
    one_hot = identity[label_array,]
    return one_hot.cuda() if y.is_cuda else one_hot

def show_example(x, y, x_reconstruction, y_pred,save_dir, figname):
    """Save a side-by-side image of an input and its reconstruction.

    The left panel shows ``x`` titled with ``y``; the right panel shows
    ``x_reconstruction`` titled with the arg-max of ``y_pred`` over its last
    axis. The figure is written to ``save_dir + figname + '.png'``; no path
    separator is inserted between the two parts.
    """
    # Permute while ``x`` is still a tensor: NumPy arrays have no ``permute``
    # method, so converting first raised AttributeError on every call.
    # NOTE(review): the squeezed tensor must have exactly three axes for this
    # permutation — confirm the expected input shape with the caller.
    x = x.squeeze().permute(0, 2, 1).cpu().data.numpy()
    y = y.cpu().data.numpy()
    x_reconstruction = x_reconstruction.squeeze().cpu().data.numpy()
    _, y_pred = torch.max(y_pred, -1)
    y_pred = y_pred.cpu().data.numpy()

    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(x, cmap='Greys')
    ax[0].set_title('Input: %d' % y)
    ax[1].imshow(x_reconstruction, cmap='Greys')
    ax[1].set_title('Output: %d' % y_pred)
    plt.savefig(save_dir + figname + '.png')

def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path):
    """Write the training state to ``<path>/checkpoint-<test_accuracy>-<epoch>.pth``.

    The saved dictionary holds the epoch, both accuracies, and the state
    dictionaries of the model and of the optimizer.
    """
    checkpoint_name = '/checkpoint-%f-%04d.pth' % (test_accuracy, epoch)
    torch.save(
        {
            'epoch': epoch,
            'train_accuracy': train_accuracy,
            'test_accuracy': test_accuracy,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        },
        path + checkpoint_name,
    )

def test(model, loader):
    """Run ``model`` over ``loader`` on the GPU and average the per-batch correct counts.

    The model is expected to return a pair whose first element holds the scores.

    Return:
        metrics: defaultdict whose ``'accuracy'`` entry is the mean correct count per batch
        hist_acc: one-element list holding that same mean
    """
    correct_counts = []
    progress = tqdm(enumerate(loader), total=len(loader),smoothing=0.9)
    for _, (x, y) in progress:
        inputs = x.float().cuda()
        labels = y.long().cuda()
        inputs = inputs.permute(0,2,1)
        scores, _ = model(inputs)
        # Arg-max over the last axis, followed by a second arg-max over axis 1.
        class_idx = torch.max(scores, -1)[1]
        pred_choice = class_idx.data.max(1)[1]
        correct_counts.append(pred_choice.eq(labels.data).cpu().sum().data)

    hist_acc = [np.mean(correct_counts)]
    metrics = defaultdict(lambda:list())
    metrics['accuracy'] = np.mean(correct_counts)
    return metrics, hist_acc

def compute_iou(pred,target,iou_tabel=None):
    """Compute per-sample, per-category IoU and accumulate it into a running table.

    Only categories that occur in a sample's target are scored for that sample.

    Input:
        pred: per-point class scores, [B, N, num_classes]
        target: per-point ground-truth labels, [B, N]
        iou_tabel: array indexed by category; column 0 accumulates the IoU sum
            and column 1 the number of scored samples. When None, a fresh
            zero table with one row per class and three columns is created.
    Return:
        mean of all IoU values computed in this call, and the updated table
    """
    if iou_tabel is None:
        # The default used to be indexed directly, which raised TypeError.
        iou_tabel = np.zeros((pred.size(-1), 3))
    ious = []
    target = target.cpu().data.numpy()
    for j in range(pred.size(0)):
        batch_target = target[j]
        # Predicted class of every point in this sample.
        batch_choice = pred[j].data.max(1)[1].cpu().data.numpy()
        for cat in np.unique(batch_target):
            in_target = batch_target == cat
            in_choice = batch_choice == cat
            intersection = np.sum(in_target & in_choice)
            # The union is never empty because ``cat`` occurs in the target.
            union = float(np.sum(in_target | in_choice))
            iou = intersection/union
            ious.append(iou)
            iou_tabel[cat,0] += iou
            iou_tabel[cat,1] += 1
    return np.mean(ious), iou_tabel

def test_seg(model, loader, catdict, num_classes = 13):
    """Evaluate a segmentation model: point accuracy and per-category mean IoU.

    Input:
        model: callable taking (coordinates [B, 3, N], features [B, D, N]) and
            returning per-point class scores [B, N, num_classes]
        loader: iterable of (points [B, N, 3+D], target [B, N]) batches
        catdict: mapping from label index to category name, one entry per label
        num_classes: number of classes in the model output
    Return:
        metrics: dict with the mean ``'accuracy'`` and the list of per-batch ``'iou'``
        hist_acc: list of per-batch accuracies
        cat_iou: pandas Series of mean IoU indexed by category name
    """
    iou_tabel = np.zeros((len(catdict),3))
    metrics = defaultdict(lambda:list())
    hist_acc = []

    # Feed the batches to whatever device the model lives on instead of
    # assuming CUDA; fall back to CUDA-if-available for parameter-free models.
    is_module = isinstance(model, torch.nn.Module)
    first_param = next(model.parameters(), None) if is_module else None
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Evaluate in eval mode (no dropout, batch-norm statistics frozen) and
    # restore the caller's mode afterwards.
    was_training = is_module and model.training
    if is_module:
        model.eval()
    try:
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for batch_id, (points, target) in tqdm(enumerate(loader), total=len(loader), smoothing=0.9):
                batchsize, num_point, _ = points.size()
                points = points.float().transpose(2, 1).to(device)
                target = target.long().to(device)
                # The first three channels are coordinates, the rest are features.
                pred = model(points[:,:3,:],points[:,3:,:])
                mean_iou, iou_tabel = compute_iou(pred,target,iou_tabel)
                pred = pred.contiguous().view(-1, num_classes)
                target = target.view(-1, 1)[:, 0]
                pred_choice = pred.data.max(1)[1]
                correct = pred_choice.eq(target.data).cpu().sum()
                metrics['accuracy'].append(correct.item()/ (batchsize * num_point))
                metrics['iou'].append(mean_iou)
    finally:
        if is_module:
            model.train(was_training)

    # The 0.01 in the denominator keeps categories that never occurred from dividing by zero.
    iou_tabel[:,2] = iou_tabel[:,0] /(iou_tabel[:,1]+0.01)
    hist_acc += metrics['accuracy']
    metrics['accuracy'] = np.mean(metrics['accuracy'])
    iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou'])
    iou_tabel['Category_IOU'] = [cat_value for cat_value in catdict.values()]
    cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean()

    return metrics, hist_acc, cat_iou

def compute_avg_curve(y, n_points_avg):
    """Return the moving average of ``y`` over windows of ``n_points_avg`` samples.

    Only positions where the window fits entirely inside ``y`` are returned.
    """
    window = np.ones((n_points_avg,)) / n_points_avg
    return np.convolve(y, window, mode='valid')

def plot_loss_curve(history,n_points_avg,n_points_plot,save_dir):
    """Plot smoothed loss curves and save them as a time-stamped PNG in ``save_dir``.

    For each of the ``'loss'``, ``'margin_loss'`` and ``'reconstruction_loss'``
    entries of ``history``, the last ``n_points_plot`` values are smoothed with
    a moving average of ``n_points_avg`` samples and drawn on one shared figure.
    """
    series_styles = (
        ('loss', '-g'),
        ('margin_loss', '-b'),
        ('reconstruction_loss', '-r'),
    )
    for key, line_style in series_styles:
        recent = np.asarray(history[key])[-n_points_plot:]
        plt.plot(compute_avg_curve(recent, n_points_avg), line_style)

    plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss'])
    stamp = str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M'))
    plt.savefig(save_dir + '/' + stamp + '_total_result.png')
    plt.close()

def plot_acc_curve(total_train_acc,total_test_acc,save_dir):
    """Plot train and test accuracy per epoch and save a time-stamped PNG in ``save_dir``."""
    curves = (
        (total_train_acc, '-b', 'train_acc'),
        (total_test_acc, '-r', 'test_acc'),
    )
    for values, line_style, curve_label in curves:
        plt.plot(values, line_style, label=curve_label)
    plt.legend()
    plt.ylabel('acc')
    plt.xlabel('epoch')
    plt.title('Accuracy of training and test')
    stamp = str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M'))
    plt.savefig(save_dir + '/' + stamp + '_total_acc.png')
    plt.close()

def show_point_cloud(tuple,seg_label=None,title=None):
    """Show a 3-D scatter plot of a point cloud, optionally coloured by segment.

    Input:
        tuple: sequence of points; the first three entries of each are used as x, y, z
        seg_label: per-point segment labels; when None or empty, all points are drawn in blue
        title: figure title
    """
    import matplotlib.pyplot as plt
    cloud = tuple  # local alias; the parameter name shadows the builtin
    ax = plt.subplot(111, projection='3d')
    # ``seg_label == []`` is ambiguous for arrays, so test for absence explicitly.
    if seg_label is None or len(seg_label) == 0:
        xs = [p[0] for p in cloud]
        ys = [p[1] for p in cloud]
        zs = [p[2] for p in cloud]
        ax.scatter(xs, ys, zs, c='b')
    else:
        cloud = np.asarray(cloud)
        labels = np.asarray(seg_label)
        # Valid Matplotlib colour names; reused cyclically if there are more segments.
        palette = ['b', 'r', 'g', 'y', 'c', 'm', 'k', 'orange',
                   'purple', 'brown', 'pink', 'gray', 'olive']
        for slot, category in enumerate(np.unique(labels)):
            members = cloud[labels == category]
            xs = [p[0] for p in members]
            ys = [p[1] for p in members]
            zs = [p[2] for p in members]
            ax.scatter(xs, ys, zs, c=palette[slot % len(palette)])
    ax.set_zlabel('Z')
    ax.set_ylabel('Y')
    ax.set_xlabel('X')
    plt.title(title)
    plt.show()

In [23]:
from torch.autograd import Variable
from tqdm import tqdm
import logging
from pathlib import Path
import datetime

# Directory layout: ./experiment/<timestamp>/{checkpoints,logs}
experiment_dir = Path('./experiment/')
print(experiment_dir)
experiment_dir.mkdir(exist_ok=True)
file_dir = experiment_dir / datetime.datetime.now().strftime('%Y-%m-%d_%H-%M')
file_dir.mkdir(exist_ok=True)
checkpoints_dir = file_dir / 'checkpoints'
checkpoints_dir.mkdir(exist_ok=True)
log_dir = file_dir / 'logs'
log_dir.mkdir(exist_ok=True)

logger = logging.getLogger('GACNet')
logger.setLevel(logging.INFO)
# Re-running this cell would otherwise stack another handler on the same
# logger and duplicate every log line.
for stale_handler in list(logger.handlers):
    logger.removeHandler(stale_handler)
    stale_handler.close()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Write into the run's own log directory created above; the former './logs/'
# location was never created, so opening the log file there could fail.
file_handler = logging.FileHandler(str(log_dir / 'train_GACNet.txt'))
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)

# model.cuda()
# model.load_state_dict(torch.load('./experiment/2023-05-20_22-47/checkpoints/GACNet_000_0.7575.pth'))

# Label index -> category name, numbering the keys of class2label in iteration order.
seg_classes = class2label
seg_label_to_cat = dict(enumerate(seg_classes.keys()))

# The batches are moved to the GPU below, so the model has to live there too;
# fall back to the CPU when no GPU is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

for epoch in range(init_epoch,2):
    # Evaluation at the end of the previous epoch left the model in eval mode.
    model.train()
    for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9):
        points, target = data
        points = points.float().transpose(2, 1).to(device)
        target = target.long().to(device)
        optimizer.zero_grad()
        # The first three channels are coordinates, the rest are point features.
        pred = model(points[:,:3,:],points[:,3:,:])
        pred = pred.contiguous().view(-1, num_classes)
        target = target.view(-1, 1)[:, 0]
        # The model outputs log-probabilities, hence the NLL loss.
        loss = F.nll_loss(pred, target)
        history['loss'].append(loss.item())
        loss.backward()
        optimizer.step()
        step += 1
        adjust_learning_rate(optimizer, step)

    # Evaluate without dropout, with frozen batch-norm statistics and without autograd.
    model.eval()
    with torch.no_grad():
        test_metrics, test_hist_acc, cat_mean_iou = test_seg(model, testdataloader, seg_label_to_cat)
    mean_iou = np.mean(cat_mean_iou)

    print('Epoch %d  %s accuracy: %f  meanIOU: %f' % (
              epoch, blue('test'), test_metrics['accuracy'],mean_iou))
    logger.info('Epoch %d  %s accuracy: %f  meanIOU: %f' % (
              epoch, 'test', test_metrics['accuracy'],mean_iou))
    # Keep a checkpoint whenever the test accuracy improves.
    if test_metrics['accuracy'] > best_acc:
        best_acc = test_metrics['accuracy']
        torch.save(model.state_dict(), '%s/GACNet_%.3d_%.4f.pth' % (checkpoints_dir, epoch, best_acc))
        logger.info(cat_mean_iou)
        logger.info('Save model..')
        print('Save model..')
        print(cat_mean_iou)
    if mean_iou > best_meaniou:
        best_meaniou = mean_iou
    print('Best accuracy is: %.5f'%best_acc)
    logger.info('Best accuracy is: %.5f'%best_acc)
    print('Best meanIOU is: %.5f'%best_meaniou)
    logger.info('Best meanIOU is: %.5f'%best_meaniou)

experiment



KeyboardInterrupt

