In [1]:
import provider
import tensorflow as tf 
import tf_ops.grouping.tf_grouping as grouping
import numpy as np
import pandas as pd 
import matplotlib.pyplot as plt
import tensorflow.keras as keras
import tensorflow.keras.layers as layers 
import time
# import dataset
import glob
import pointnet_util as utils


In [2]:
# ModelNet40 HDF5 shards. glob returns files in file-system order, which can
# differ between machines; sorting makes the data concatenation order reproducible.
DATA_DIR = "modelnet40_ply_hdf5_2048"
TRAIN_FILES = sorted(glob.glob(f"{DATA_DIR}/*train*.h5"))
TEST_FILES = sorted(glob.glob(f"{DATA_DIR}/*test*.h5"))

In [3]:
class self_attention(keras.Model):
    """Multi-head attention block that pools axis -2 of a rank-4 feature tensor.

    For each of the ``k_head`` heads the input is projected by three 1x1
    convolutions (query / key / value). The query is averaged over axis -2
    into a single vector, scored against every key (scaled by
    ``sqrt(out_dim)``), and the softmaxed scores weight the values. Head
    outputs are concatenated on the channel axis, the now size-1 axis -2 is
    squeezed away, and the result passes through a ReLU Dense layer and
    BatchNormalization.

    Args:
        out_dim: channels produced by every projection and by the output.
        k_head: number of attention heads.
    """

    def __init__(self, out_dim, k_head):
        super().__init__()
        self.out_dim = out_dim
        self.k = k_head

    def build(self, input_shape):
        # Attribute names (convq/convk/convv/Dense_out/norm/soft) are part of
        # the checkpoint layout, so they must stay exactly as they are.
        def head_projections():
            return [layers.Conv2D(self.out_dim, 1, 1) for _ in range(self.k)]

        self.convq = head_projections()
        self.convk = head_projections()
        self.convv = head_projections()
        self.Dense_out = layers.Dense(self.out_dim, activation="relu")
        self.norm = keras.layers.BatchNormalization()
        self.soft = keras.layers.Softmax()

    def call(self, input_feature, training=True):
        # assumes input_feature is rank 4 (the key transpose uses perm [0, 1, 3, 2])
        scale = tf.sqrt(tf.cast(self.out_dim, tf.float32))
        head_outputs = []
        for to_query, to_key, to_value in zip(self.convq, self.convk, self.convv):
            query = to_query(input_feature)
            value = to_value(input_feature)
            key = to_key(input_feature)
            pooled_query = tf.reduce_mean(query, axis=-2, keepdims=True)
            scores = tf.matmul(pooled_query, tf.transpose(key, [0, 1, 3, 2])) / scale
            head_outputs.append(tf.matmul(self.soft(scores), value))
        merged = tf.squeeze(tf.concat(head_outputs, axis=-1), axis=-2)
        return self.norm(self.Dense_out(merged), training=training)


In [4]:
def random_sample(xyz):
    """Return the first half (rounded down) of the points along axis 1.

    This is not random on its own: it only yields a random subset when the
    caller has already shuffled the point order.
    """
    half = tf.shape(xyz)[1] // 2
    return xyz[:, :half, :]


In [5]:
class pointcloud_class(keras.Model):
    """Hierarchical point-cloud classifier built from attention-pooled neighbourhoods.

    Forward pass (see ``call``):
      1. Lift every point to a 128-d feature with two Dense+BatchNorm layers.
      2. Repeat ``self.n`` times:
         - keep half of the current points as centres (``random_sample``);
         - gather each centre's ``nsample`` nearest neighbours;
         - summarise each neighbourhood by concatenating an attention-pooled,
           mean-centred feature with a max-pooled feature.
         The neighbourhood centroid becomes the next level's point position.
      3. Pool the remaining points with a final attention block and classify
         with Dense + softmax.

    Args:
        class_num: number of output classes.
        nsample: neighbours gathered per centre at every level. With 1024
            input points and 7 levels the last level groups from 16 points,
            so values above 16 would ask ``top_k`` for more neighbours than exist.
    """

    def __init__(self, class_num=10, nsample=16):
        super(pointcloud_class, self).__init__()
        self.class_num = class_num
        self.nsample = nsample

    def build(self, inputshape):
        # Number of sample-and-group levels; each level halves the point count.
        self.n = 7
        self.Dense1 = keras.layers.Dense(64, activation="relu")
        self.Dense2 = keras.layers.Dense(128, activation="relu")
        self.norm1 = keras.layers.BatchNormalization()
        self.norm2 = keras.layers.BatchNormalization()
        # One attention block per level, plus one for the final global pooling.
        self.self_attention = [self_attention(128, 1) for _ in range(self.n + 1)]
        self.Dense = [layers.Dense(128, activation="relu") for _ in range(self.n)]
        self.Dense3 = layers.Dense(self.class_num)
        self.soft = keras.layers.Softmax()

    def _knn_indices(self, xyz, centers):
        """Indices into ``xyz`` of the ``self.nsample`` points nearest each centre.

        assumes xyz is (B, N, 3) and centers is (B, M, 3) — TODO confirm.
        Returns an int32 tensor of shape (B, M, nsample).
        """
        # ||c - p||^2 = ||c||^2 - 2 c.p + ||p||^2 avoids materialising (B, M, N, 3).
        sq_dist = (
            tf.reduce_sum(tf.square(centers), axis=-1, keepdims=True)
            - 2.0 * tf.matmul(centers, xyz, transpose_b=True)
            + tf.expand_dims(tf.reduce_sum(tf.square(xyz), axis=-1), axis=1)
        )
        _, idx = tf.math.top_k(-sq_dist, k=self.nsample)
        return idx

    def call(self, points_xyz, training=True):
        feature = self.norm1(self.Dense1(points_xyz), training=training)
        feature = self.norm2(self.Dense2(feature), training=training)
        old_xyz = points_xyz
        for i in range(self.n):
            centers = random_sample(old_xyz)
            # NOTE(review): the original grouping call here was truncated
            # (`... = utils.`) and did not parse, so the intended pointnet_util
            # helper is unknown. It is reconstructed as a kNN grouping of the
            # previous level's points around the sampled centres — confirm.
            idx = self._knn_indices(old_xyz, centers)
            grouped_xyz = grouping.group_point(old_xyz, idx)
            grouped_feature = grouping.group_point(feature, idx)
            # The next level's positions are the neighbourhood centroids.
            new_xyz = tf.reduce_mean(grouped_xyz, axis=-2)
            centered_feature = grouped_feature - tf.reduce_mean(
                grouped_feature, axis=-2, keepdims=True)
            # Max over neighbours of the same gathered features (the old code
            # read a misspelled `grouped_featur` from the truncated call).
            global_feature = tf.reduce_max(grouped_feature, axis=-2)
            local_feature = self.self_attention[i](centered_feature, training=training)
            feature = self.Dense[i](tf.concat([local_feature, global_feature], axis=-1))
            old_xyz = new_xyz
        _, new_points, _ = utils.grouping_all(feature, new_xyz)
        out = self.self_attention[self.n](new_points, training=training)
        out = tf.squeeze(out, axis=1)
        return self.soft(self.Dense3(out))


In [6]:
# Exponential learning-rate decay: 0.001 * 0.7 ** (step / 100000).
lr_schedule = keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=0.001, decay_steps=100000, decay_rate=0.7)

In [7]:

# The training loop overwrites this learning rate from lr_schedule every batch.
optimizer = keras.optimizers.Adam(learning_rate=0.0001)
# The model ends in a softmax, so the loss receives probabilities, not logits.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
model = pointcloud_class(class_num=40)
# Shared running metrics: accuracy and mean cross-entropy.
mtric1 = keras.metrics.SparseCategoricalAccuracy()
mtric2 = keras.metrics.SparseCategoricalCrossentropy()

In [8]:
# `loss` is already created (identically) in the optimizer/model cell above;
# the redundant re-definition that lived here has been removed.

In [9]:
# Training log; opened once for the whole session (truncated on open).
LOG_FOUT = open('log_train.txt', 'w')


def log_string(out_str):
    """Write ``out_str`` as one line to the log file and echo it to stdout.

    The file is flushed after every line so it can be followed while training.
    """
    line = out_str + '\n'
    LOG_FOUT.write(line)
    LOG_FOUT.flush()
    print(out_str)


In [10]:
@tf.function
def callmodel(input,current_label):
    """Run one optimisation step on a batch and update the training metrics.

    The objective is ``loss(current_label, predictions)`` plus the sum of
    ``model.losses`` when the model has registered any.
    """
    with tf.GradientTape() as tape:
        predictions = model(input, training=True)
        extra_loss = tf.math.add_n(model.losses) if model.losses else 0
        total_loss = loss(current_label, predictions) + extra_loss
    gradients = tape.gradient(total_loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    for metric in (mtric1, mtric2):
        metric.update_state(current_label, predictions)
@tf.function
def evaluate_model(input,current_label):
    """Forward a batch in inference mode and accumulate it into the metrics."""
    predictions = model(input, training=False)
    for metric in (mtric1, mtric2):
        metric.update_state(current_label, predictions)



In [11]:
def shuffledata(data):
    """Shuffle the point order of every cloud in ``data`` in place.

    ``data`` is indexed as ``data[cloud, point, channel]``. One index array is
    re-shuffled with ``np.random.shuffle`` before each cloud and used to
    reorder that cloud's points. The same array object is returned.
    """
    order = np.arange(data.shape[1])
    for cloud in data:  # each `cloud` is a writable view into `data`
        np.random.shuffle(order)
        cloud[...] = cloud[order, :]
    return data
        

In [12]:

def train_one_epoch(epoch):
    """Train for one pass over ``datasets`` with rotation/jitter augmentation.

    For each batch:
      - increment the global ``BATCH`` counter and ``ckpt.batch``;
      - rotate, jitter and point-shuffle the clouds (via ``provider`` and
        ``shuffledata``);
      - set the optimizer's learning rate from ``lr_schedule``;
      - run ``callmodel``.

    ``epoch`` is currently unused.
    """
    # Must equal the batch size ``datasets`` was built with (256). The previous
    # value 128 made the schedule step count only half of the samples seen.
    BATCH_SIZE = 256
    global BATCH
    for current_data, current_label in datasets:
        BATCH = BATCH + 1
        ckpt.batch.assign_add(1)
        # Augment batched point clouds by rotation and jittering.
        rotated_data = provider.rotate_point_cloud(current_data.numpy())
        jittered_data = provider.jitter_point_cloud(rotated_data)
        jittered_data = shuffledata(jittered_data.astype(np.float32))
        tf.keras.backend.set_value(optimizer.learning_rate,
                                   lr_schedule(BATCH * BATCH_SIZE))
        callmodel(jittered_data, current_label)
def test_one_epoch(epoch):
    """Accumulate metrics over every batch of ``test_datasets`` (``epoch`` is unused)."""
    for batch_points, batch_labels in test_datasets:
        evaluate_model(batch_points, batch_labels)

        

In [13]:
# Global batch counter; train_one_epoch increments it and derives the
# learning-rate schedule step from it.
BATCH = 1

In [14]:
# Checkpoint the model, optimizer and a batch counter; the keyword names
# (model / opti / batch) are the keys stored in the checkpoint.
ckpt = tf.train.Checkpoint(model=model, opti=optimizer, batch=tf.Variable(1))
# Keep only the three most recent checkpoints under ./pointcloud_class.
ckpt_mana = tf.train.CheckpointManager(ckpt, directory="pointcloud_class", max_to_keep=3)
epochs = 200



In [15]:
def _load_point_clouds(files, num_points=1024):
    """Load every file in ``files`` and concatenate the first ``num_points`` points of each cloud.

    Raises:
        FileNotFoundError: if ``files`` is empty (e.g. the data directory is missing).
    """
    if not files:
        raise FileNotFoundError(
            "no .h5 files to load; check that the ModelNet40 data directory exists")
    clouds, labels = [], []
    for path in files:
        cloud, lab = provider.loadDataFile(path)
        clouds.append(cloud[:, 0:num_points, :])
        labels.append(lab)
    return np.concatenate(clouds, axis=0), np.concatenate(labels, axis=0)


def _make_dataset(files, batch_size=256, shuffle_buffer=1024, num_points=1024):
    """Build a shuffled, batched, prefetched tf.data pipeline over ``files``.

    The point order inside each cloud is re-shuffled every time the cloud is
    read. The model keeps the first half of the points as centres, so this
    shuffle is what makes that selection random.
    """
    points, labels = _load_point_clouds(files, num_points)
    ds = tf.data.Dataset.from_tensor_slices((points, labels))
    ds = ds.shuffle(buffer_size=shuffle_buffer)
    ds = ds.map(lambda x, y: (tf.random.shuffle(x), y),
                num_parallel_calls=tf.data.experimental.AUTOTUNE)
    return ds.batch(batch_size).prefetch(1)


datasets = _make_dataset(TRAIN_FILES)
test_datasets = _make_dataset(TEST_FILES)







In [16]:
def _log_metrics_and_reset(split):
    """Log the mean loss/accuracy accumulated for ``split`` and clear both metrics.

    Resetting after each split keeps test batches out of the next epoch's
    training statistics (both splits share mtric1/mtric2).
    """
    log_string('%s_mean loss: %f' % (split, float(mtric2.result())))
    log_string('%s_accuracy: %f' % (split, float(mtric1.result())))
    mtric1.reset_states()
    mtric2.reset_states()


for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    train_one_epoch(epoch)
    ckpt_mana.save()
    log_string("-----{}------".format(epoch))
    _log_metrics_and_reset('train')
    test_one_epoch(epoch)
    _log_metrics_and_reset('test')



Start of epoch 0


In [1]:
# Scratch cell: the stray, unused `a=1` assignment was removed.