In [1]:
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from chamfer_distance import ChamferDistance
from torch.autograd import Variable
import glob
import torch.nn.functional as F

In [2]:
l = glob.glob("/datasets/cs253-wi20-public/ShapeNet_pointclouds/*/*/*2048.npy")

class PointCloudDataset(Dataset):

    def __init__(self, lis= None):
        
        point_clouds = []
        for file_name in lis:

            points = np.load(file_name)
            point_clouds.append(points)

        self.point_clouds = np.array(point_clouds, dtype='float64')
        self.point_clouds = np.transpose(self.point_clouds, (0, 2, 1))
        
    def __len__(self):
        return len(self.point_clouds)

    def __getitem__(self, idx):
        return self.point_clouds[idx]

In [3]:
class Tnet(nn.Module):
    def __init__(self, k=3):
        super().__init__()
        self.k=k
        self.conv1 = nn.Conv1d(k,64,1)
        self.conv2 = nn.Conv1d(64,128,1)
        self.conv3 = nn.Conv1d(128,1024,1)
        self.fc1 = nn.Linear(1024,512)
        self.fc2 = nn.Linear(512,256)
        self.fc3 = nn.Linear(256,k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)
       

    def forward(self, input):
        # input.shape == (bs,n,3)
        bs = input.size(0)
        xb = F.relu(self.bn1(self.conv1(input)))
        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = F.relu(self.bn3(self.conv3(xb)))
        xb = nn.MaxPool1d(xb.size(-1))(xb)
        xb = nn.Flatten(1)(xb)
        xb = F.relu(self.bn4(self.fc1(xb)))
        xb = F.relu(self.bn5(self.fc2(xb)))
      
      #initialize as identity
        init = torch.eye(self.k, requires_grad=True).repeat(bs,1,1)
        if xb.is_cuda:
            init=init.cuda()
        matrix = self.fc3(xb).view(-1,self.k,self.k) + init
        return matrix


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_transform = Tnet(k=3)
        self.feature_transform = Tnet(k=64)
        self.fc1 = nn.Conv1d(3,64,1)
        self.fc2 = nn.Conv1d(64,64,1) 
        self.fc4 = nn.Conv1d(64,128,1)
        self.fc5 = nn.Conv1d(128,1024,1)

        
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        self.bn4 = nn.BatchNorm1d(128)
        self.bn5 = nn.BatchNorm1d(1024)

    def forward(self, input):
        n_pts = input.size()[2]
        matrix3x3 = self.input_transform(input)
        xb = torch.bmm(torch.transpose(input,1,2), matrix3x3).transpose(1,2)
        xb = F.relu(self.bn1(self.fc1(xb)))
        xb = F.relu(self.bn2(self.fc2(xb)))
        matrix128x128 = self.feature_transform(xb)
        xb = torch.bmm(torch.transpose(xb,1,2), matrix128x128).transpose(1,2) 
        xb = F.relu(self.bn4(self.fc4(xb)))
        xb = self.bn5(self.fc5(xb))
        xb = nn.MaxPool1d(xb.size(-1))(xb)        
        
        return xb

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 128)
        self.fc4 = nn.Linear(128,2048*3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.bn3 = nn.BatchNorm1d(128)
        
    def forward(self, input):
        input = input.view(input.size(0), -1)
        out = F.relu(self.bn1(self.fc1(input)))
        out = F.relu(self.bn2(self.fc2(out)))
        out = F.relu(self.bn3(self.fc3(out)))
        out =  torch.tanh(self.fc4(out))
        return out.view(-1, 3, 2048)
        
        
    

In [4]:
class EarlyStopping():
    """
    Early Stopping to terminate training early under certain conditions
    """

    def __init__(self, min_delta=0, patience=50):
        
        """
        EarlyStopping callback to exit the training loop if training or
        validation loss does not improve by a certain amount for a certain
        number of epochs
        Arguments
        ---------
        min_delta : float
            minimum change in monitored value to qualify as improvement.
            This number should be positive.
        patience : integer
            number of epochs to wait for improvment before terminating.
            the counter be reset after each improvment
        """
        
        self.min_delta = min_delta
        self.patience = patience
        self.wait = 0
        self.best_loss = 1e-15
        self.stopped_epoch = 0

    def on_train_begin(self):
        self.wait = 0
        self.best_loss = 1e15

    def on_epoch_end(self, epoch, current_loss):

        if current_loss is None:
            pass
        else:
            if (current_loss - self.best_loss) < -self.min_delta:
                self.best_loss = current_loss
                self.wait = 1
            else:
                if self.wait >= self.patience:
                    return True
                self.wait += 1
        
        return False

In [None]:
point_cloud_dataset = PointCloudDataset(l)
dataloader = DataLoader(point_cloud_dataset, batch_size=32, shuffle=True)
#enc = Encoder()
#dec = Decoder()
enc = torch.load('encoder')
dec = torch.load('decoder')
chamferDist = ChamferDistance()
#early_stopping = EarlyStopping(patience=50)

num_epochs = 1000
learning_rate = 1e-3

optimizere = torch.optim.Adam(enc.parameters(), lr=learning_rate)
optimizerd = torch.optim.Adam(dec.parameters(), lr=learning_rate)
if torch.cuda.is_available():
    enc = enc.cuda()
    dec = dec.cuda()
loss_list = []
for epoch in range(num_epochs):
    
   # if epoch == 0:
  #      early_stopping.on_train_begin()

    running_loss = 0.0
    
    for data in dataloader:
        if torch.cuda.is_available():
            data = data.cuda()
            
        train_output = enc(data.float())
        out= dec(train_output)
        dist1, dist2 = chamferDist(out.float(), data.float())
        loss = (torch.mean(dist1)) + (torch.mean(dist2))

        optimizere.zero_grad()
        optimizerd.zero_grad()
        loss.backward()
        optimizere.step()
        optimizerd.step()
        
        running_loss += loss.data.detach() 
    
    loss_list.append(running_loss / data.shape[0])
    
  #  if early_stopping.on_epoch_end(epoch + 1, running_loss):
   #     print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss_list[-1]))
  #      print('Terminated Training for Early Stopping at Epoch %04i' % (epoch + 1))
    torch.save(enc, 'encoder')
    torch.save(dec, 'decoder')
            
    print('epoch [{}/{}], loss:{:.4f}'.format(epoch + 1, num_epochs, loss_list[-1]))    

  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "
  "type " + obj.__name__ + ". It won't be checked "


epoch [1/1000], loss:12.9899
epoch [2/1000], loss:11.9559
