In [None]:
!pip install torch-geometric

In [None]:
import os
import sys
import glob
import h5py
import torch
import tqdm
import numpy as np
import pandas as pd
import scipy.io as sio
import torch.nn.functional as F
import plotly.express as px
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset
from torch_geometric.data import Data,HeteroData
from torch_geometric.data import DataLoader
from torch_geometric.datasets import MovieLens
from torch_geometric.nn import RGCNConv
from torch_geometric.utils import dropout_adj
from sklearn.model_selection import train_test_split

In [None]:

def download():
    """Download and unpack the ModelNet40 HDF5 archive into ``./data`` unless already present.

    Shells out to ``wget`` and ``unzip`` (both must be on PATH), which run in the
    current working directory; the extracted folder is then moved under ``./data``
    and the downloaded archive is removed.

    Raises:
        subprocess.CalledProcessError: if ``wget`` or ``unzip`` exits non-zero, instead of
            silently continuing and failing later with an empty file list.
    """
    import shutil
    import subprocess

    BASE_DIR = './'
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    os.makedirs(DATA_DIR, exist_ok=True)
    if os.path.exists(os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048')):
        return
    # NOTE(review): this host has been reported offline; if the download fails,
    # obtain the archive from a mirror and unpack it into ./data manually.
    www = 'https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip'
    zipfile = os.path.basename(www)
    try:
        # Argument lists (no shell) and check=True so any failure stops here.
        subprocess.run(['wget', '--no-check-certificate', www], check=True)
        # -o: overwrite a leftover extracted folder instead of prompting interactively.
        subprocess.run(['unzip', '-o', zipfile], check=True)
        shutil.move(zipfile[:-4], DATA_DIR)
    finally:
        if os.path.exists(zipfile):
            os.remove(zipfile)


def load_data(partition):
    """Load every ModelNet40 HDF5 shard for a split and stack them.

    Args:
        partition: split name substituted into ``ply_data_<partition>*.h5``
            (the notebook uses ``'train'`` and ``'test'``).

    Returns:
        ``(data, label)``: ``data`` as float32 and ``label`` as int64, each the
        concatenation along axis 0 of the shards' ``'data'`` / ``'label'`` datasets.

    Raises:
        FileNotFoundError: if no shard matches the split.
    """
    download()
    BASE_DIR = './'
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    pattern = os.path.join(DATA_DIR, 'modelnet40_ply_hdf5_2048', 'ply_data_%s*.h5' % partition)
    # glob order is filesystem-dependent; sort so sample indices are reproducible.
    h5_names = sorted(glob.glob(pattern))
    if not h5_names:
        raise FileNotFoundError('No ModelNet40 HDF5 files match %r' % pattern)
    all_data = []
    all_label = []
    for h5_name in h5_names:
        # Explicit read-only mode (old h5py defaulted to 'a', which can create files);
        # the context manager closes the file even if reading fails.
        with h5py.File(h5_name, 'r') as f:
            all_data.append(f['data'][:].astype('float32'))
            all_label.append(f['label'][:].astype('int64'))
    return np.concatenate(all_data, axis=0), np.concatenate(all_label, axis=0)


def load_scanobjectnn_data(partition):
    """Load one ScanObjectNN split from ``./data/<partition>_objectdataset_augmentedrot_scale75.h5``.

    Args:
        partition: file-name prefix (the ScanObjectNN class defaults to ``'training'``).

    Returns:
        ``(data, label)``: the file's ``'data'`` dataset as float32 and ``'label'`` as int64.
    """
    BASE_DIR = './'
    DATA_DIR = os.path.join(BASE_DIR, 'data')
    h5_name = os.path.join(DATA_DIR, partition + '_objectdataset_augmentedrot_scale75.h5')
    # Read-only mode, and the file is closed even if a read fails.
    with h5py.File(h5_name, 'r') as f:
        data = f['data'][:].astype('float32')
        label = f['label'][:].astype('int64')
    return data, label


def translate_pointcloud(pointcloud):
    """Random per-axis affine augmentation.

    Each of the 3 coordinates is multiplied by a factor drawn from U(2/3, 3/2) and
    then shifted by an offset drawn from U(-0.2, 0.2). Returns a new float32 array.
    """
    scale = np.random.uniform(low=2. / 3., high=3. / 2., size=[3])
    shift = np.random.uniform(low=-0.2, high=0.2, size=[3])
    scaled = np.multiply(pointcloud, scale)
    return np.add(scaled, shift).astype('float32')


def jitter_pointcloud(pointcloud, sigma=0.01, clip=0.02):
    """Add clipped Gaussian noise to every coordinate of a ``(N, C)`` array.

    Args:
        pointcloud: 2-D numpy array of points.
        sigma: standard deviation of the noise.
        clip: noise is clipped to ``[-clip, clip]``.

    Returns:
        A new array; the caller's array is NOT modified. Floating inputs keep their
        dtype, non-floating inputs are promoted to float32.
    """
    N, C = pointcloud.shape
    noise = np.clip(sigma * np.random.randn(N, C), -1 * clip, clip)
    # Out-of-place: an in-place `+=` would write into the caller's buffer, which for a
    # Dataset slice is the stored data itself, so noise would accumulate across epochs.
    out_dtype = pointcloud.dtype if np.issubdtype(pointcloud.dtype, np.floating) else np.float32
    return (pointcloud + noise).astype(out_dtype)


class ModelNet40(Dataset):
    """ModelNet40 point clouds loaded from the HDF5 shards via ``load_data``.

    Args:
        num_points: how many leading points of each cloud to keep.
        partition: split name; ``'train'`` additionally enables augmentation
            (random scale/shift followed by a point-order shuffle).
    """

    def __init__(self, num_points, partition='train'):
        self.data, self.label = load_data(partition)
        self.num_points = num_points
        self.partition = partition

    def __getitem__(self, item):
        """Return ``(points, label)`` as tensors for sample ``item``."""
        cloud = self.data[item][:self.num_points]
        target = self.label[item]
        if self.partition != 'train':
            return torch.tensor(cloud), torch.tensor(target)
        # translate_pointcloud returns a fresh array, so the in-place shuffle
        # below never touches self.data.
        cloud = translate_pointcloud(cloud)
        np.random.shuffle(cloud)
        return torch.tensor(cloud), torch.tensor(target)

    def __len__(self):
        return len(self.data)


class ScanObjectNN(Dataset):
    """ScanObjectNN point clouds (``augmentedrot_scale75`` files) via ``load_scanobjectnn_data``.

    Args:
        num_points: how many leading points of each cloud to keep.
        partition: file-name prefix; ``'training'`` additionally enables augmentation
            (random scale/shift followed by a point-order shuffle).
    """

    def __init__(self, num_points, partition='training'):
        self.data, self.label = load_scanobjectnn_data(partition)
        self.num_points = num_points
        self.partition = partition

    def __getitem__(self, item):
        """Return ``(points, label)`` as tensors, matching ``ModelNet40``.

        NOTE(review): the notebook batches with torch_geometric's DataLoader, whose
        collater appears to reject raw numpy arrays, so tensors are returned here.
        """
        pointcloud = self.data[item][:self.num_points]
        label = self.label[item]
        if self.partition == 'training':
            pointcloud = translate_pointcloud(pointcloud)
            np.random.shuffle(pointcloud)
        return torch.tensor(pointcloud), torch.tensor(label)

    def __len__(self):
        return self.data.shape[0]




# Both ModelNet40 splits, each cloud truncated to its first 128 points.
# NOTE: `train` is rebound later in this notebook by `def train()`; the DataLoader
# built from this dataset must be created before that cell runs.
train, test = (ModelNet40(num_points=128, partition=split) for split in ('train', 'test'))


In [None]:
# Reshuffle the training set every epoch so mini-batches vary between epochs;
# evaluation order is irrelevant, so the test loader keeps a fixed order.
# NOTE(review): torch_geometric.data.DataLoader is a deprecated alias in newer PyG
# (torch_geometric.loader.DataLoader); confirm against the installed version.
train_loader = DataLoader(train, batch_size=64, shuffle=True)
test_loader = DataLoader(test, batch_size=64)



In [None]:
class EDM(torch.nn.Module):
    """Point-cloud classifier combining per-point MLPs with a truncated spectral graph filter.

    For each cloud a dense graph is built with edge weights
    ``max_pairwise_distance - distance``. Its symmetric normalized Laplacian
    ``I - D^{-1/2} W D^{-1/2}`` is approximated by its top-``k`` eigenpairs and applied
    as a fixed linear operator to point features after each of the first two layers.

    Input:  ``inp`` of shape ``(batch, num_points, 3)``.
    Output: unnormalized logits of shape ``(batch, 40)`` (no softmax; use cross-entropy).
    """

    def __init__(self):
        super(EDM, self).__init__()
        self.fc1 = torch.nn.Linear(3, 32)
        self.bn1 = torch.nn.BatchNorm1d(32)
        self.fc2 = torch.nn.Linear(64, 64)     # fc1 features (32) ++ filtered features (32)
        self.bn2 = torch.nn.BatchNorm1d(64)
        self.fc3 = torch.nn.Linear(192, 128)   # x1 (64) ++ x2 (128)
        self.bn3 = torch.nn.BatchNorm1d(128)
        self.output = torch.nn.Linear(128, 40)

    def extract_distance_matrix(self, inp):
        """Return edge weights ``W[b, i, j] = max_dist[b] - ||p_i - p_j||``.

        ``max_dist[b]`` is the largest pairwise distance in cloud ``b``. The farthest
        pair therefore gets weight 0, and the diagonal (self-loops) gets ``max_dist[b]``.
        """
        diff = torch.unsqueeze(inp, 2) - torch.unsqueeze(inp, 1)
        dist = torch.sqrt(torch.sum(diff ** 2, dim=-1))
        max_dist = torch.amax(dist, dim=(1, 2), keepdim=True)
        return max_dist - dist

    def normalize_distance_matrix(self, inp, eps=1e-12):
        """Return the symmetric normalized Laplacian ``I - D^{-1/2} W D^{-1/2}`` of ``inp``.

        ``D`` is the row sum of ``W``. It is clamped to ``eps`` so that a degenerate cloud
        (all points coincident, hence all weights 0) yields ``I`` instead of inf/NaN.
        """
        degree = torch.sum(inp, dim=2).clamp_min(eps)
        d_inv_sqrt = torch.diag_embed(1.0 / torch.sqrt(degree))
        identity = torch.diag_embed(torch.ones_like(degree, dtype=torch.float32))
        return identity - torch.bmm(d_inv_sqrt, torch.bmm(inp, d_inv_sqrt))

    def extract_simplified_SVD(self, inp, k=32):
        """Rank-``k`` eigen-approximation of a batch of symmetric matrices.

        Returns ``(U_k, S_k, U_k^T)``, where ``S_k`` is diagonal with the ``k`` largest
        eigenvalues in descending order. For positive semi-definite input (such as a
        normalized Laplacian with non-negative weights), this equals the truncated
        SVD, and ``U_k @ S_k @ U_k^T`` is the best rank-``k`` approximation.
        """
        # eigh is the symmetric solver: real eigenvalues (ascending) and orthonormal
        # eigenvectors. linalg.eig returns complex, *unordered* eigenpairs, so slicing
        # its first k columns selected arbitrary components.
        eigvals, eigvecs = torch.linalg.eigh(inp)
        eigvals = torch.flip(eigvals, dims=[-1])[..., :k]
        eigvecs = torch.flip(eigvecs, dims=[-1])[..., :k]
        return eigvecs, torch.diag_embed(eigvals), eigvecs.transpose(1, 2)

    @staticmethod
    def _dense_bn_relu(fc, bn, x):
        """Per-point Linear -> BatchNorm1d over the feature axis -> ReLU, keeping (B, N, C)."""
        x = fc(x)
        return F.relu(bn(x.transpose(1, 2)).transpose(1, 2))

    @staticmethod
    def _spectral_filter(U, S, Vh, x):
        """Apply ``relu(U @ S @ Vh @ x)`` batch-wise (truncated Laplacian as a graph filter)."""
        return F.relu(torch.bmm(U, torch.bmm(S, torch.bmm(Vh, x))))

    def forward(self, inp):
        """Classify a batch of point clouds shaped ``(batch, num_points, 3)``."""
        laplacian = self.normalize_distance_matrix(self.extract_distance_matrix(inp))
        U, S, Vh = self.extract_simplified_SVD(laplacian)

        x = self._dense_bn_relu(self.fc1, self.bn1, inp)
        x1 = torch.cat([x, self._spectral_filter(U, S, Vh, x)], dim=-1)   # (B, N, 64)

        x = self._dense_bn_relu(self.fc2, self.bn2, x1)
        x2 = torch.cat([x, self._spectral_filter(U, S, Vh, x)], dim=-1)   # (B, N, 128)

        x = self._dense_bn_relu(self.fc3, self.bn3, torch.cat([x1, x2], dim=-1))  # (B, N, 128)

        # Global max-pool over the point axis. Unlike max_pool1d(kernel_size=128), this
        # works for any num_points without mixing clouds across the batch dimension.
        x = torch.amax(x, dim=1)
        # F.dropout defaults to training=True; tie it to the module mode so eval is
        # deterministic. Return logits only: F.cross_entropy applies log-softmax itself.
        return self.output(F.dropout(x, 0.3, training=self.training))


In [None]:
def loss_fn(output, gt):
    """Cross-entropy loss plus raw accuracy counts for one batch.

    Args:
        output: logits, one row per sample.
        gt: integer class labels; flattened to one label per sample.

    Returns:
        ``(loss, num_correct, batch_size)``. ``num_correct`` is a 0-d tensor, and
        ``batch_size`` is an int.
    """
    targets = gt.view(-1)
    predictions = output.argmax(dim=-1).view(-1)
    n_correct = (predictions == targets).sum()
    return F.cross_entropy(output, targets), n_correct, targets.size(0)



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EDM().to(device)


# NOTE: this def rebinds the name `train`, which previously held the ModelNet40 training
# dataset. train_loader already references that dataset, but re-running the DataLoader
# cell after this one would pass this function to DataLoader instead.
def train(epochs=500, lr=0.001):
    """Train the global ``model`` on ``train_loader`` and evaluate on ``test_loader`` each epoch.

    Args:
        epochs: number of passes over the training set.
        lr: RMSprop learning rate.

    Returns:
        The best test accuracy observed across all epochs.

    Each epoch prints the per-sample mean train/test loss, both accuracies and the
    running best test accuracy.
    """
    optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
    max_val_acc = 0
    for epoch in range(epochs):
        model.train()
        train_loss, train_corrects, train_records = 0.0, 0, 0
        for data, label in train_loader:
            data = data.to(device)
            label = label.to(device)
            loss, sum_, num_ = loss_fn(model(data), label)
            # One optimizer step per mini-batch. Summing every batch's loss and calling
            # backward once per epoch made a single update per epoch and kept every
            # batch's autograd graph alive until then.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Accumulate Python numbers so no graph is retained across batches.
            train_loss += loss.item() * num_
            train_corrects += sum_.item()
            train_records += num_

        model.eval()
        test_loss, test_corrects, test_records = 0.0, 0, 0
        with torch.no_grad():  # evaluation needs no autograd graph
            for data, label in test_loader:
                data = data.to(device)
                label = label.to(device)
                loss, sum_, num_ = loss_fn(model(data), label)
                test_loss += loss.item() * num_
                test_corrects += sum_.item()
                test_records += num_

        train_acc = train_corrects / max(train_records, 1)
        val_acc = test_corrects / max(test_records, 1)
        max_val_acc = max(max_val_acc, val_acc)
        print(f'Epoch: {epoch}, Loss: {train_loss / max(train_records, 1)},"Acc":{train_acc},  val_Loss: {test_loss / max(test_records, 1)}, val_Acc:{val_acc}, max_val_Acc: {max_val_acc}')
    return max_val_acc

In [None]:
# Run the full training loop; it prints train/validation metrics once per epoch.
train()