In [1]:
import numpy as np
import math
import random
import os
import torch
import scipy.spatial.distance 
from torch.utils.data import Dataset, DataLoader 
from torchvision import transforms, utils 
import torch.nn.functional as F
from time import time
import torch.nn as nn
import math
import copy



In [2]:
# Seed every RNG this notebook uses. The previous `random.seed=42` *replaced*
# the function with an int: nothing was seeded and later calls would fail.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Root of the ModelNet40 dataset. Set MODELNET40_DIR to point elsewhere instead
# of editing this machine-specific default.
path = os.environ.get(
    "MODELNET40_DIR",
    os.path.join("/home/parvez", "Dataset/ModelNet40_dataset/ModelNet40"),
)
print(path)

/home/parvez/Dataset/ModelNet40_dataset/ModelNet40


In [4]:
# Map each entry of the dataset root (alphabetical order) to an integer label.
folders = sorted(os.listdir(path))
classes = {name: label for label, name in enumerate(folders)}


In [5]:
def read_off(file):
    """Parse an OFF mesh from an open text file.

    Accepts both the standard layout (``OFF`` alone on the first line, counts on
    the next) and the variant where the counts are glued onto the header,
    e.g. ``OFF490 518 0``.

    Args:
        file: readable text file object positioned at the start of the mesh.

    Returns:
        (verts, faces): ``verts`` is a list of float lists (one per vertex line);
        ``faces`` is a list of int lists with the leading per-face vertex count
        dropped.

    Raises:
        ValueError: if the header does not start with ``OFF`` or a count/number
            cannot be parsed.
    """
    off_header = file.readline().strip()
    if not off_header.startswith('OFF'):
        raise ValueError(f"not an OFF file: header is {off_header!r}")
    # Counts follow on the next line unless they were glued onto the header.
    counts_line = off_header[3:].strip() or file.readline().strip()
    # split() (not split(' ')) tolerates repeated spaces, tabs and trailing blanks.
    counts = [int(s) for s in counts_line.split()]
    if len(counts) < 2:
        raise ValueError(f"malformed OFF counts line: {counts_line!r}")
    n_verts, n_faces = counts[0], counts[1]
    verts = [[float(s) for s in file.readline().split()] for _ in range(n_verts)]
    faces = [[int(s) for s in file.readline().split()][1:] for _ in range(n_faces)]
    return verts, faces


In [6]:
# Load one sample mesh to sanity-check read_off and the samplers below.
with open(os.path.join(path, 'bed/train/bed_0001.off'), mode='r') as f:
    verts, faces = read_off(f)
      

In [7]:
# Column-split the sample mesh: (i, j, k) are the face vertex-index columns,
# (x, y, z) the vertex coordinate columns.
i, j, k = np.transpose(np.array(faces))
x, y, z = np.transpose(np.array(verts))

In [8]:
class PointSampler(object):
    """Sample points uniformly from the surface of a triangle mesh.

    Faces are drawn with probability proportional to their area, then one point
    is drawn uniformly inside each drawn triangle. Randomness comes from the
    stdlib ``random`` module.

    Args:
        output_size: number of points returned per call.
    """

    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size

    def triangle_area(self, pt1, pt2, pt3):
        """Return the area of triangle (pt1, pt2, pt3).

        Computed as half the norm of the edge cross product. Heron's formula
        would lose most of its precision on long, thin triangles through
        catastrophic cancellation in ``s - side``.
        """
        origin = np.asarray(pt1, dtype=float)
        edge1 = np.asarray(pt2, dtype=float) - origin
        edge2 = np.asarray(pt3, dtype=float) - origin
        return 0.5 * float(np.linalg.norm(np.cross(edge1, edge2)))

    def sample_point(self, pt1, pt2, pt3):
        """Draw a uniform random point inside triangle (pt1, pt2, pt3).

        Sorting two uniforms yields barycentric weights (s, t - s, 1 - t) that
        are uniformly distributed over the triangle.
        """
        s, t = sorted([random.random(), random.random()])
        f = lambda i: s * pt1[i] + (t - s) * pt2[i] + (1 - t) * pt3[i]
        return (f(0), f(1), f(2))

    def __call__(self, mesh):
        """Sample ``output_size`` surface points from ``mesh``.

        Args:
            mesh: ``(verts, faces)``; only the first three indices of each face
                are used.

        Returns:
            ``(output_size, 3)`` float array.

        Raises:
            ValueError: if the mesh has no faces or zero total surface area.
        """
        verts, faces = mesh
        verts = np.asarray(verts, dtype=float)
        if len(faces) == 0:
            raise ValueError("cannot sample points from a mesh with no faces")
        tri = np.asarray([face[:3] for face in faces], dtype=int)
        a, b, c = verts[tri[:, 0]], verts[tri[:, 1]], verts[tri[:, 2]]
        # Vectorised version of triangle_area over all faces at once.
        areas = 0.5 * np.linalg.norm(np.cross(b - a, c - a), axis=1)
        # `not > 0` also catches NaN totals.
        if not areas.sum() > 0:
            raise ValueError("cannot sample points from a mesh with zero surface area")

        sampled_faces = random.choices(tri, weights=areas, k=self.output_size)

        sampled_points = np.zeros((self.output_size, 3))
        for n, (i0, i1, i2) in enumerate(sampled_faces):
            sampled_points[n] = self.sample_point(verts[i0], verts[i1], verts[i2])
        return sampled_points
    
    

In [9]:
# Draw 3000 surface points from the sample mesh.
pointcloud = PointSampler(output_size=3000)(mesh=(verts, faces))

In [10]:
# Expect (3000, 3): one row per sampled point.
print(np.shape(pointcloud))

(3000, 3)


In [11]:
class Normalize(object):
    """Center a point cloud at its centroid and scale it into the unit sphere."""

    def __call__(self, pointcloud):
        """Return a normalized copy of ``pointcloud`` (the input is not modified).

        Args:
            pointcloud: 2-D array with one point per row.

        Returns:
            ``pointcloud`` minus its centroid, divided by the largest distance of
            any point from that centroid. If all points coincide (radius 0), the
            centered all-zero cloud is returned instead of dividing by zero.
        """
        assert len(pointcloud.shape) == 2

        norm_pointcloud = pointcloud - np.mean(pointcloud, axis=0)
        radius = np.max(np.linalg.norm(norm_pointcloud, axis=1))
        if radius > 0:
            norm_pointcloud /= radius

        return norm_pointcloud
    


In [12]:
# Center the sampled cloud and scale it into the unit sphere.
norm_pointcloud = Normalize()(pointcloud=pointcloud)

In [13]:
class ToTensor(object):
    """Convert a 2-D NumPy point cloud into a torch tensor via ``torch.from_numpy``."""

    def __call__(self, pointcloud):
        # Only 2-D (points x coordinates) arrays are accepted.
        assert len(pointcloud.shape) == 2
        tensor = torch.from_numpy(pointcloud)
        return tensor
    

In [14]:
def default_transforms():
    """Return the standard mesh preprocessing pipeline.

    Steps: sample 1024 surface points, normalize into the unit sphere, then
    convert to a torch tensor.
    """
    steps = [PointSampler(1024), Normalize(), ToTensor()]
    return transforms.Compose(steps)



In [15]:
class PointCloudData(Dataset):
    """Dataset with one sample per ``.off`` mesh under ``root_dir/<class>/<folder>/``.

    Args:
        root_dir: dataset root; each sub-directory is one class, labelled in
            alphabetical order.
        valid: if True, ``transform`` is ignored and ``default_transforms()`` is
            used instead.
        folder: split sub-directory read inside every class directory
            (e.g. 'train' or 'test').
        transform: callable applied to the ``(verts, faces)`` tuple. NOTE: the
            default is built once, when this class is defined.

    Items are dicts ``{'pointcloud': transformed mesh, 'category': int label}``.
    """

    def __init__(self, root_dir, valid=False, folder='train', transform=default_transforms()):
        self.root_dir = root_dir
        # Only directories are classes; a stray file (README, .DS_Store, ...)
        # would otherwise get a label and crash the listdir below.
        class_dirs = [name for name in sorted(os.listdir(root_dir))
                      if os.path.isdir(os.path.join(root_dir, name))]
        self.classes = {name: label for label, name in enumerate(class_dirs)}
        self.transforms = transform if not valid else default_transforms()
        self.valid = valid
        self.files = []
        for category in self.classes.keys():
            new_dir = os.path.join(root_dir, category, folder)
            # Sorted so sample indices are reproducible across filesystems.
            for file in sorted(os.listdir(new_dir)):
                if file.endswith('.off'):
                    self.files.append({'pcd_path': os.path.join(new_dir, file),
                                       'category': category})

    def __len__(self):
        return len(self.files)

    def __preproc__(self, file):
        """Parse an open OFF file and apply the configured transform, if any."""
        verts, faces = read_off(file)
        if self.transforms:
            return self.transforms((verts, faces))
        # No transform configured: return the raw mesh instead of failing with
        # UnboundLocalError.
        return verts, faces

    def __getitem__(self, idx):
        pcd_path = self.files[idx]['pcd_path']
        category = self.files[idx]['category']
        with open(pcd_path, 'r') as f:
            pointcloud = self.__preproc__(f)
        return {'pointcloud': pointcloud,
                'category': self.classes[category]}
    
    
    

In [16]:
# Training split uses the (default) transform; the test split forces default_transforms.
train_ds = PointCloudData(root_dir=path)
valid_ds = PointCloudData(root_dir=path, folder='test', valid=True)


In [17]:
# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=32)
valid_loader = DataLoader(valid_ds, batch_size=64)


In [18]:
def fps(pcd, n_samples):
    """Farthest point sampling, independently for every cloud in a batch.

    Starts from point 0, then repeatedly picks the not-yet-selected point whose
    squared distance to its nearest selected point is largest. Ties go to the
    lowest index.

    Args:
        pcd: tensor or array of shape (B, N, D).
        n_samples: points to keep per cloud, 1 <= n_samples <= N.

    Returns:
        (B, n_samples, D) tensor of torch's default float dtype, on ``pcd``'s
        device (CPU for array input). No gradient flows through the sampling.

    Raises:
        ValueError: if ``n_samples`` is outside [1, N].
    """
    source = torch.as_tensor(pcd)
    B, N, D = source.shape
    if not 1 <= n_samples <= N:
        raise ValueError(f"n_samples must be in [1, {N}], got {n_samples}")

    # detach().cpu() so grad-requiring and CUDA inputs can be read as NumPy.
    all_points = source.detach().cpu().numpy()
    kernel = torch.zeros(B, n_samples, D)

    for n in range(B):
        points = all_points[n]
        sample_inds = np.zeros(n_samples, dtype=int)
        dists = np.full(N, np.inf)
        selected = np.zeros(N, dtype=bool)

        last_added = 0
        sample_inds[0] = last_added
        selected[last_added] = True

        for i in range(1, n_samples):
            dist_to_last_added = ((points - points[last_added]) ** 2).sum(-1)
            dists = np.minimum(dists, dist_to_last_added)
            # Already-selected points can never be picked again.
            last_added = int(np.argmax(np.where(selected, -np.inf, dists)))
            sample_inds[i] = last_added
            selected[last_added] = True

        kernel[n] = torch.from_numpy(points[sample_inds])

    return kernel.to(source.device)


In [19]:
def knn(x, k):
    """Indices of the k nearest neighbours of every point (each point included).

    Args:
        x: (batch, channels, num_points) tensor.
        k: neighbours per point.

    Returns:
        (batch, num_points, k) index tensor, nearest first.
    """
    gram = torch.matmul(x.transpose(2, 1).contiguous(), x)
    sq_norms = (x ** 2).sum(dim=1, keepdim=True)
    # -||xi - xj||^2 = 2 xi.xj - ||xj||^2 - ||xi||^2, so the largest is the nearest.
    neg_sq_dist = 2 * gram - sq_norms - sq_norms.transpose(2, 1).contiguous()
    _, neighbour_idx = neg_sq_dist.topk(k=k, dim=-1)
    return neighbour_idx

In [20]:
def get_graph_feature(x, k=20):
    """Build per-edge features from each point's k nearest neighbours.

    Args:
        x: (batch_size, num_dims, num_points) tensor.
        k: neighbours per point (must not exceed num_points).

    Returns:
        (batch_size, 2 * num_dims, num_points, k) tensor. For point i and
        neighbour j the channels are [x_j, x_i].
        NOTE(review): the reference DGCNN uses [x_j - x_i, x_i].
    """
    idx = knn(x, k=k)  # (batch_size, num_points, k)
    batch_size, num_points, _ = idx.size()

    # Offset neighbour indices into the flattened (batch * points) axis. The
    # offsets are created on idx's device; a CPU arange fails for CUDA input.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    idx = (idx + idx_base).view(-1)

    _, num_dims, _ = x.size()
    x = x.transpose(2, 1).contiguous()  # (batch_size, num_points, num_dims)

    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((feature, x), dim=3).permute(0, 3, 1, 2)

In [21]:
class DGCNN(nn.Module):
    """DGCNN-style per-point feature extractor.

    The edge features are built once from the input's k-NN graph. They pass
    through four 1x1 convs; after each conv the features are max-pooled over
    the neighbours. The four pooled maps are concatenated (64+64+128+256 = 512
    channels) and projected to ``emb_dims``.

    Args:
        emb_dims: output embedding width per point.
        input_shape: "bnc" for (B, N, 3) input, "bcn" for (B, 3, N).
    """

    def __init__(self, emb_dims=1024, input_shape="bnc"):
        super(DGCNN, self).__init__()
        if input_shape not in ["bcn", "bnc"]:
            raise ValueError("Allowed shapes are 'bcn' (batch*channels*num_in_points),'bnc' ")
        self.input_shape = input_shape
        self.emb_dims = emb_dims

        self.conv1 = nn.Conv2d(6, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=1, bias=False)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=1, bias=False)
        self.conv4 = nn.Conv2d(128, 256, kernel_size=1, bias=False)
        self.conv5 = nn.Conv2d(512, emb_dims, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(64)
        self.bn3 = nn.BatchNorm2d(128)
        self.bn4 = nn.BatchNorm2d(256)
        self.bn5 = nn.BatchNorm2d(emb_dims)

    def forward(self, input_data):
        """Return (B, N, emb_dims) per-point embeddings for either input layout."""
        if self.input_shape == "bnc":
            input_data = input_data.permute(0, 2, 1)
        if input_data.shape[1] != 3:
            raise RuntimeError("shape of x must be of [Batch x 3 x NumInPoints]")

        batch_size, num_dims, num_points = input_data.size()
        output = get_graph_feature(input_data)  # (B, 6, N, k)

        output = F.relu(self.bn1(self.conv1(output)))
        output1 = output.max(dim=-1, keepdim=True)[0]

        output = F.relu(self.bn2(self.conv2(output)))
        output2 = output.max(dim=-1, keepdim=True)[0]

        output = F.relu(self.bn3(self.conv3(output)))
        output3 = output.max(dim=-1, keepdim=True)[0]

        output = F.relu(self.bn4(self.conv4(output)))
        output4 = output.max(dim=-1, keepdim=True)[0]

        output = torch.cat((output1, output2, output3, output4), dim=1)  # (B, 512, N, 1)

        output = F.relu(self.bn5(self.conv5(output)))  # (B, emb_dims, N, 1)
        # Move channels last with a real transpose. A bare .view(B, N, -1) would
        # reinterpret the channel-major memory and mix features across points.
        return output.view(batch_size, -1, num_points).permute(0, 2, 1).contiguous()

In [22]:
class SelfAttention(nn.Module):
    """Parameter-free scaled dot-product self-attention with q = k = v = x."""

    def __init__(self):
        super(SelfAttention, self).__init__()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Attend over the sequence axis.

        Args:
            x: (B, N, D) tensor.

        Returns:
            (B, N, D): softmax(x x^T / sqrt(D)) @ x.
        """
        B, N, D = x.shape
        scale = D ** -0.5

        q = k = v = x
        # Scale the logits *before* the softmax. Scaling the probabilities
        # afterwards only shrinks the output and leaves the logits unscaled.
        weights = self.softmax(torch.bmm(q, k.transpose(1, 2)) * scale)
        return torch.bmm(weights, v)
       

In [23]:
class Encoder(nn.Module):
    """Encode proxy features with a DGCNN branch and a self-attention branch.

    Args:
        input_dim: feature width of the input. The default 1088 is the width
            PoinTr in this notebook feeds in (1024-d DGCNN embedding + 64-d MLP
            feature).
    """

    def __init__(self, input_dim=1088):
        super(Encoder, self).__init__()
        self.input_dim = input_dim

        self.self_attn = SelfAttention()
        self.dgcnn = DGCNN()

        # Projection to 3 channels, which DGCNN.forward requires. The layers are
        # built here rather than in forward so they are registered sub-modules:
        # trained by the optimizer, moved by .to()/.cuda(), saved in state_dict,
        # and not re-randomized on every call.
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 3)

    def forward(self, x):
        """
        Args:
            x: (B, N, input_dim) tensor.

        Returns:
            (B, N, dgcnn.emb_dims + input_dim): DGCNN features concatenated with
            the self-attention output.
        """
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))

        attn = self.self_attn(x)
        out = self.dgcnn(out)

        return torch.cat([out, attn], dim=2)


In [24]:
class CrossAttention(nn.Module):
    """Parameter-free scaled dot-product cross-attention (keys = values = v)."""

    def __init__(self):
        super(CrossAttention, self).__init__()
        self.softmax = nn.Softmax(dim=2)

    def forward(self, q, v):
        """Attend from queries ``q`` to ``v``.

        Args:
            q: (B, N, D) queries.
            v: (B, M, D) keys/values; the feature width must match ``q``'s.

        Returns:
            (B, N, D): softmax(q v^T / sqrt(D)) @ v.
        """
        B, N, D = q.shape
        scale = D ** -0.5

        k = v
        # The scale was computed before but never applied; apply it to the logits.
        weights = self.softmax(torch.bmm(q, k.transpose(1, 2)) * scale)
        return torch.bmm(weights, v)
    


In [25]:
class Decoder(nn.Module):
    """Decode query embeddings against the encoded proxies.

    Args:
        input_dim: feature width of ``v`` (the encoded proxies). The default
            2112 is the Encoder output width PoinTr produces in this notebook
            (1024-d DGCNN embedding + 1088-d encoder input).
    """

    def __init__(self, input_dim=2112):
        super(Decoder, self).__init__()
        self.input_dim = input_dim

        self.cross_attn = CrossAttention()
        self.dgcnn = DGCNN()

        # Built here (not in forward) so the weights are registered, trained,
        # moved with the model and not re-randomized on every call.
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 3)

    def forward(self, q, v):
        """
        Args:
            q: (B, N, input_dim) query embeddings.
            v: (B, N, input_dim) encoded proxies. The torch.cat below needs the
                same number of queries and proxies.

        Returns:
            (B, N, dgcnn.emb_dims + input_dim): DGCNN features of ``v``
            concatenated with the cross-attention output.
        """
        out = F.relu(self.fc1(v))
        out = F.relu(self.fc2(out))

        attn = self.cross_attn(q, v)
        out = self.dgcnn(out)

        return torch.cat((out, attn), dim=2)
       

In [26]:
class QueryGenerator(nn.Module):
    """Predict coarse proxy coordinates and query embeddings from encoded proxies.

    Args:
        dim: feature width of the input and of the returned query embeddings.
    """

    def __init__(self, dim):
        super(QueryGenerator, self).__init__()

        self.dim = dim

        # Per-proxy MLP whose output is max-pooled into the global feature.
        self.fc1 = nn.Linear(dim, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 256)

        # For generating coordinates.
        self.fc4 = nn.Linear(dim, 64)
        self.fc5 = nn.Linear(64, 128)
        self.fc6 = nn.Linear(128, 256)
        self.fc7 = nn.Linear(256, 128)
        self.fc8 = nn.Linear(128, 64)
        self.fc9 = nn.Linear(64, 3)

        # For generating query embeddings from [coordinate (3) | global feature (256)].
        self.fc10 = nn.Linear(3 + 256, 64)
        self.fc11 = nn.Linear(64, 128)
        self.fc12 = nn.Linear(128, dim)

    def forward(self, x):
        """
        Args:
            x: (B, N, dim) encoded proxies.

        Returns:
            coordinate: (B, N, 3) predicted coordinates (may be negative).
            q_emb: (B, N, dim) query embeddings.
        """
        # Global feature g: max-pool over the proxy axis (dim=1), not the
        # channel axis, giving one 256-d descriptor per cloud.
        f = F.relu(self.fc1(x))
        f = F.relu(self.fc2(f))
        f = F.relu(self.fc3(f))
        g = torch.max(f, dim=1, keepdim=True)[0]  # (B, 1, 256)

        coordinate = F.relu(self.fc4(x))
        coordinate = F.relu(self.fc5(coordinate))
        coordinate = F.relu(self.fc6(coordinate))
        coordinate = F.relu(self.fc7(coordinate))
        coordinate = F.relu(self.fc8(coordinate))
        # No ReLU on the last layer: normalized clouds are centered at the
        # origin, so coordinates must be able to go negative.
        coordinate = self.fc9(coordinate)

        # Concatenate each coordinate with the (broadcast) global feature.
        g = g.expand(-1, coordinate.shape[1], -1)
        q_emb = torch.cat((coordinate, g), dim=2)

        q_emb = F.relu(self.fc10(q_emb))
        q_emb = F.relu(self.fc11(q_emb))
        q_emb = F.relu(self.fc12(q_emb))

        return coordinate, q_emb
    

In [49]:
class FoldingNet(nn.Module):
    """Two-stage folding head: codeword + random 2-D seeds -> 3-D points.

    Args:
        input_dim: feature width of the input. The default 3136 is the Decoder
            output width PoinTr produces in this notebook.
        num_points: number of points generated per cloud.
    """

    def __init__(self, input_dim=3136, num_points=2050):
        super(FoldingNet, self).__init__()
        self.input_dim = input_dim
        self.num_points = num_points

        # Layers live here (not in forward) so they are registered, trained,
        # moved with the model and not re-randomized on every call.
        # First fold: [codeword | 2-D seed] -> 3-D.
        self.fc1 = nn.Linear(input_dim + 2, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 3)
        # Second fold: [first fold | codeword] -> 3-D.
        self.fc4 = nn.Linear(input_dim + 3, 256)
        self.fc5 = nn.Linear(256, 128)
        self.fc6 = nn.Linear(128, 3)

    def forward(self, input):
        """
        Args:
            input: (B, N, input_dim) features; max-pooled over N into the codeword.

        Returns:
            (B, num_points, 3) generated points.
        """
        B, N, d = input.shape
        m = self.num_points

        # Random 2-D seeds, created on the input's device/dtype.
        grid = torch.randn(B, m, 2, device=input.device, dtype=input.dtype)
        codeword = torch.max(input, dim=1, keepdim=True)[0].repeat(1, m, 1)

        feature = torch.cat((codeword, grid), dim=2)
        feature = F.relu(self.fc1(feature))
        feature = F.relu(self.fc2(feature))
        # Folded coordinates are left linear: a ReLU here would confine them to
        # the positive octant.
        feature = self.fc3(feature)

        feature = torch.cat((feature, codeword), dim=2)
        out = F.relu(self.fc4(feature))
        out = F.relu(self.fc5(out))
        return self.fc6(out)


In [None]:
class ChamferDistance(nn.Module):
    """Symmetric squared-L2 Chamfer distance between batched point clouds.

    For X (B, N, D) and Y (B, M, D):
        CD = mean_x min_y ||x - y||^2 + mean_y min_x ||y - x||^2,
    averaged over the batch.
    """

    def __init__(self):
        super(ChamferDistance, self).__init__()

    def forward(self, x, y):
        """Return the batch-mean Chamfer distance as a scalar tensor."""
        xx = (x ** 2).sum(dim=2, keepdim=True)   # (B, N, 1)
        yy = (y ** 2).sum(dim=2).unsqueeze(1)    # (B, 1, M)
        # ||x||^2 + ||y||^2 - 2 x.y; clamp tiny negatives caused by rounding.
        dist = (xx + yy - 2 * torch.bmm(x, y.transpose(1, 2))).clamp(min=0)
        x_to_y = dist.min(dim=2)[0].mean(dim=1)
        y_to_x = dist.min(dim=1)[0].mean(dim=1)
        return (x_to_y + y_to_x).mean()
        

In [62]:
class PoinTr(nn.Module):
    """Point-cloud completion: proxies -> encoder -> queries -> decoder -> folding.

    Args:
        num_samples: proxies picked per cloud by farthest point sampling. The
            folding head's point count must be a multiple of it (2050 = 41 * 50
            with the defaults).
    """

    def __init__(self, num_samples=50):
        super(PoinTr, self).__init__()

        self.num_samples = num_samples

        # Per-proxy MLP on raw proxy coordinates: 3 -> 64.
        self.fc1 = nn.Linear(3, 64)
        self.fc2 = nn.Linear(64, 128)
        self.fc3 = nn.Linear(128, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 64)

        self.dgcnn = DGCNN()
        self.encoder = Encoder()
        self.decoder = Decoder()
        self.foldingNet = FoldingNet()

        # The encoder concatenates its own DGCNN embedding with its input (our
        # DGCNN embedding + fc5 features). The query generator is built here,
        # not in forward, so it is registered, trained and moved with the model.
        proxy_dim = (self.encoder.dgcnn.emb_dims
                     + self.dgcnn.emb_dims
                     + self.fc5.out_features)
        self.q_gen = QueryGenerator(proxy_dim)

    def forward(self, input):
        """Complete a batch of point clouds.

        Args:
            input: (B, N, 3) partial clouds.

        Returns:
            (B, N + M, 3): the input points followed by the M points produced
            by the folding head.

        Raises:
            ValueError: if M is not a multiple of ``num_samples``.
        """
        kernels = fps(input, self.num_samples).to(input.device)  # (B, num_samples, 3)

        x = F.relu(self.fc1(kernels))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))

        F_i = torch.cat((self.dgcnn(kernels), x), dim=2)

        encoded_point_proxies = self.encoder(F_i)
        coordinate, q_emb = self.q_gen(encoded_point_proxies)
        predicted_proxies = self.decoder(q_emb, encoded_point_proxies)
        predicted_pcd = self.foldingNet(predicted_proxies)

        # Tile the proxy coordinates to match the folded point count instead of
        # assuming a fixed ratio of 41.
        n_pred, n_proxies = predicted_pcd.shape[1], coordinate.shape[1]
        if n_pred % n_proxies != 0:
            raise ValueError(
                f"folding output ({n_pred} points) is not a multiple of "
                f"num_samples ({n_proxies})")
        predicted_pcd = predicted_pcd + coordinate.repeat(1, n_pred // n_proxies, 1)

        # Match the input dtype (e.g. float64 from ToTensor) before concatenating.
        return torch.cat((input, predicted_pcd.to(input.dtype)), dim=1)
    
    
    