In [1]:
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
import numpy as np
import sys
from tqdm import tqdm 
import json
import torch.nn.functional as F
import argparse
import random
import torch.optim as optim
import time


In [2]:
class PointDataset(data.Dataset):
    def __init__(self,
                 root='../pointData',
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.fns = []
        
        with open(os.path.join(root, '{}.txt'.format(self.split)).replace('\\', '/'), 'r') as f:
            for line in f:
                self.fns.append(line.strip())
        
        with open(os.path.join(root, 'cat.txt').replace('\\', '/'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
                
        print(self.cat) 
        self.classes = list(self.cat.keys())
        
    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        data = [[], [], []]
        
        with open(os.path.join(self.root, fn).replace('\\', '/'), 'r') as f:
            for line in f:
                c, d, e= line.split()
                data[0].append(c)  
                data[1].append(d)  
                data[2].append(e)  
        x = [float(data[0]) for data[0] in data[0]]
        y = [float(data[1]) for data[1] in data[1]]
        z = [float(data[2]) for data[2] in data[2]]
        pts = np.vstack([x,y,z]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        #all points minus average of columns(i.e. xyz)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale
        
        '''
        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter
        '''
        
        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        return point_set, cls
    
    def __len__(self):
        return len(self.fns)


In [3]:
class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x

class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x
    
class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
        
class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat
    
def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

In [4]:
blue = lambda x: '\033[94m' + x + '\033[0m'

manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  4254


<torch._C.Generator at 0x246e82885a0>

In [5]:
dataset = PointDataset(
    root='../pointData',
    split='train')

test_dataset = PointDataset(
    root='../pointData',
    split='test',
    data_augmentation=False)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True)

testdataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=True,
        drop_last=True)

print(len(dataset), len(test_dataset))
num_classes = len(dataset.classes)
print('classes', num_classes)

{'cap': '0', 'plat': '1', 'vertical': '2'}
{'cap': '0', 'plat': '1', 'vertical': '2'}
180 60
classes 3


In [6]:
try:
    os.makedirs('cls')
except OSError:
    pass

In [7]:
classifier = PointNetCls(k=num_classes, feature_transform=True)
'''
if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))
'''
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

classifier.cuda()


num_batch = len(dataset) / 32



In [8]:
for epoch in range(250):
    scheduler.step()
    time_start=time.time()
    for i, data in enumerate(dataloader, 0):
        points, target = data
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        loss = F.nll_loss(pred, target)
        if True:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(32)))

        if i % 10 == 0:
            if epoch%10==0:
                time_record=time.time()
                t=time_record-time_start
                print('time spent in recent ten epoches is %f s'% t)
                time_start=time.time()
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(32)))

    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % ('cls', epoch))



[0: 0/5] train loss: 1.238093 accuracy: 0.437500
time spent in recent ten epoches is 1.736353 s
[0: 0/5] test loss: 1.092245 accuracy: 0.437500
[0: 1/5] train loss: 1.273098 accuracy: 0.406250
[0: 2/5] train loss: 0.999177 accuracy: 0.562500
[0: 3/5] train loss: 1.073175 accuracy: 0.718750
[0: 4/5] train loss: 0.943079 accuracy: 0.687500
[1: 0/5] train loss: 0.672133 accuracy: 0.750000
[1: 0/5] test loss: 1.122416 accuracy: 0.312500
[1: 1/5] train loss: 0.731493 accuracy: 0.812500
[1: 2/5] train loss: 0.930152 accuracy: 0.625000
[1: 3/5] train loss: 0.649185 accuracy: 0.812500
[1: 4/5] train loss: 0.644321 accuracy: 0.875000
[2: 0/5] train loss: 0.489903 accuracy: 0.937500
[2: 0/5] test loss: 1.315618 accuracy: 0.281250
[2: 1/5] train loss: 0.563829 accuracy: 0.843750
[2: 2/5] train loss: 0.708762 accuracy: 0.781250
[2: 3/5] train loss: 0.485781 accuracy: 0.875000
[2: 4/5] train loss: 0.859684 accuracy: 0.781250
[3: 0/5] train loss: 0.502406 accuracy: 0.875000
[3: 0/5] test loss: 2.178

[27: 0/5] train loss: 0.079342 accuracy: 0.968750
[27: 0/5] test loss: 0.844140 accuracy: 0.718750
[27: 1/5] train loss: 0.065910 accuracy: 1.000000
[27: 2/5] train loss: 0.195000 accuracy: 0.968750
[27: 3/5] train loss: 0.115880 accuracy: 0.937500
[27: 4/5] train loss: 0.300284 accuracy: 0.906250
[28: 0/5] train loss: 0.073593 accuracy: 0.968750
[28: 0/5] test loss: 0.458504 accuracy: 0.843750
[28: 1/5] train loss: 0.188586 accuracy: 0.937500
[28: 2/5] train loss: 0.045470 accuracy: 1.000000
[28: 3/5] train loss: 0.077018 accuracy: 1.000000
[28: 4/5] train loss: 0.131215 accuracy: 0.968750
[29: 0/5] train loss: 0.077486 accuracy: 1.000000
[29: 0/5] test loss: 0.496868 accuracy: 0.875000
[29: 1/5] train loss: 0.052642 accuracy: 1.000000
[29: 2/5] train loss: 0.083316 accuracy: 0.968750
[29: 3/5] train loss: 0.098087 accuracy: 1.000000
[29: 4/5] train loss: 0.099877 accuracy: 0.968750
[30: 0/5] train loss: 0.177767 accuracy: 0.937500
time spent in recent ten epoches is 0.248766 s
[30: 0

[53: 4/5] train loss: 0.018671 accuracy: 1.000000
[54: 0/5] train loss: 0.014875 accuracy: 1.000000
[54: 0/5] test loss: 0.875420 accuracy: 0.750000
[54: 1/5] train loss: 0.024676 accuracy: 1.000000
[54: 2/5] train loss: 0.049802 accuracy: 1.000000
[54: 3/5] train loss: 0.109827 accuracy: 0.937500
[54: 4/5] train loss: 0.024528 accuracy: 1.000000
[55: 0/5] train loss: 0.023758 accuracy: 1.000000
[55: 0/5] test loss: 0.445618 accuracy: 0.843750
[55: 1/5] train loss: 0.032723 accuracy: 1.000000
[55: 2/5] train loss: 0.044506 accuracy: 0.968750
[55: 3/5] train loss: 0.027601 accuracy: 1.000000
[55: 4/5] train loss: 0.134384 accuracy: 0.937500
[56: 0/5] train loss: 0.105605 accuracy: 0.968750
[56: 0/5] test loss: 0.307194 accuracy: 0.843750
[56: 1/5] train loss: 0.026570 accuracy: 1.000000
[56: 2/5] train loss: 0.047152 accuracy: 1.000000
[56: 3/5] train loss: 0.026402 accuracy: 1.000000
[56: 4/5] train loss: 0.031150 accuracy: 1.000000
[57: 0/5] train loss: 0.059393 accuracy: 0.968750
[57

[80: 3/5] train loss: 0.016681 accuracy: 1.000000
[80: 4/5] train loss: 0.011702 accuracy: 1.000000
[81: 0/5] train loss: 0.017045 accuracy: 1.000000
[81: 0/5] test loss: 0.355517 accuracy: 0.875000
[81: 1/5] train loss: 0.010763 accuracy: 1.000000
[81: 2/5] train loss: 0.016235 accuracy: 1.000000
[81: 3/5] train loss: 0.016756 accuracy: 1.000000
[81: 4/5] train loss: 0.015276 accuracy: 1.000000
[82: 0/5] train loss: 0.011025 accuracy: 1.000000
[82: 0/5] test loss: 0.333091 accuracy: 0.906250
[82: 1/5] train loss: 0.015364 accuracy: 1.000000
[82: 2/5] train loss: 0.016550 accuracy: 1.000000
[82: 3/5] train loss: 0.010822 accuracy: 1.000000
[82: 4/5] train loss: 0.011331 accuracy: 1.000000
[83: 0/5] train loss: 0.015179 accuracy: 1.000000
[83: 0/5] test loss: 0.352806 accuracy: 0.906250
[83: 1/5] train loss: 0.010370 accuracy: 1.000000
[83: 2/5] train loss: 0.012157 accuracy: 1.000000
[83: 3/5] train loss: 0.011140 accuracy: 1.000000
[83: 4/5] train loss: 0.027379 accuracy: 1.000000
[84

[107: 2/5] train loss: 0.010298 accuracy: 1.000000
[107: 3/5] train loss: 0.009570 accuracy: 1.000000
[107: 4/5] train loss: 0.010024 accuracy: 1.000000
[108: 0/5] train loss: 0.010940 accuracy: 1.000000
[108: 0/5] test loss: 0.573030 accuracy: 0.875000
[108: 1/5] train loss: 0.009395 accuracy: 1.000000
[108: 2/5] train loss: 0.009786 accuracy: 1.000000
[108: 3/5] train loss: 0.009910 accuracy: 1.000000
[108: 4/5] train loss: 0.011364 accuracy: 1.000000
[109: 0/5] train loss: 0.013826 accuracy: 1.000000
[109: 0/5] test loss: 0.608712 accuracy: 0.843750
[109: 1/5] train loss: 0.010994 accuracy: 1.000000
[109: 2/5] train loss: 0.009067 accuracy: 1.000000
[109: 3/5] train loss: 0.010112 accuracy: 1.000000
[109: 4/5] train loss: 0.012706 accuracy: 1.000000
[110: 0/5] train loss: 0.010052 accuracy: 1.000000
time spent in recent ten epoches is 0.239942 s
[110: 0/5] test loss: 0.302750 accuracy: 0.875000
[110: 1/5] train loss: 0.010588 accuracy: 1.000000
[110: 2/5] train loss: 0.013175 accura

[133: 4/5] train loss: 0.008856 accuracy: 1.000000
[134: 0/5] train loss: 0.014051 accuracy: 1.000000
[134: 0/5] test loss: 0.453765 accuracy: 0.843750
[134: 1/5] train loss: 0.010159 accuracy: 1.000000
[134: 2/5] train loss: 0.009069 accuracy: 1.000000
[134: 3/5] train loss: 0.009306 accuracy: 1.000000
[134: 4/5] train loss: 0.009167 accuracy: 1.000000
[135: 0/5] train loss: 0.008984 accuracy: 1.000000
[135: 0/5] test loss: 0.374133 accuracy: 0.906250
[135: 1/5] train loss: 0.008800 accuracy: 1.000000
[135: 2/5] train loss: 0.014130 accuracy: 1.000000
[135: 3/5] train loss: 0.010042 accuracy: 1.000000
[135: 4/5] train loss: 0.137726 accuracy: 0.968750
[136: 0/5] train loss: 0.012895 accuracy: 1.000000
[136: 0/5] test loss: 0.600033 accuracy: 0.843750
[136: 1/5] train loss: 0.010166 accuracy: 1.000000
[136: 2/5] train loss: 0.010504 accuracy: 1.000000
[136: 3/5] train loss: 0.009303 accuracy: 1.000000
[136: 4/5] train loss: 0.007910 accuracy: 1.000000
[137: 0/5] train loss: 0.011097 ac

[160: 0/5] test loss: 0.405430 accuracy: 0.875000
[160: 1/5] train loss: 0.010168 accuracy: 1.000000
[160: 2/5] train loss: 0.011245 accuracy: 1.000000
[160: 3/5] train loss: 0.016495 accuracy: 1.000000
[160: 4/5] train loss: 0.009247 accuracy: 1.000000
[161: 0/5] train loss: 0.011386 accuracy: 1.000000
[161: 0/5] test loss: 0.254259 accuracy: 0.937500
[161: 1/5] train loss: 0.012116 accuracy: 1.000000
[161: 2/5] train loss: 0.008567 accuracy: 1.000000
[161: 3/5] train loss: 0.009785 accuracy: 1.000000
[161: 4/5] train loss: 0.010635 accuracy: 1.000000
[162: 0/5] train loss: 0.011836 accuracy: 1.000000
[162: 0/5] test loss: 0.644138 accuracy: 0.812500
[162: 1/5] train loss: 0.008666 accuracy: 1.000000
[162: 2/5] train loss: 0.010824 accuracy: 1.000000
[162: 3/5] train loss: 0.014656 accuracy: 1.000000
[162: 4/5] train loss: 0.008629 accuracy: 1.000000
[163: 0/5] train loss: 0.009529 accuracy: 1.000000
[163: 0/5] test loss: 0.459504 accuracy: 0.843750
[163: 1/5] train loss: 0.009663 acc

[186: 3/5] train loss: 0.011709 accuracy: 1.000000
[186: 4/5] train loss: 0.016930 accuracy: 1.000000
[187: 0/5] train loss: 0.011725 accuracy: 1.000000
[187: 0/5] test loss: 0.485810 accuracy: 0.843750
[187: 1/5] train loss: 0.009125 accuracy: 1.000000
[187: 2/5] train loss: 0.008728 accuracy: 1.000000
[187: 3/5] train loss: 0.010751 accuracy: 1.000000
[187: 4/5] train loss: 0.008430 accuracy: 1.000000
[188: 0/5] train loss: 0.010273 accuracy: 1.000000
[188: 0/5] test loss: 0.346064 accuracy: 0.906250
[188: 1/5] train loss: 0.008620 accuracy: 1.000000
[188: 2/5] train loss: 0.011772 accuracy: 1.000000
[188: 3/5] train loss: 0.008543 accuracy: 1.000000
[188: 4/5] train loss: 0.008431 accuracy: 1.000000
[189: 0/5] train loss: 0.009365 accuracy: 1.000000
[189: 0/5] test loss: 0.573379 accuracy: 0.843750
[189: 1/5] train loss: 0.011102 accuracy: 1.000000
[189: 2/5] train loss: 0.009474 accuracy: 1.000000
[189: 3/5] train loss: 0.010432 accuracy: 1.000000
[189: 4/5] train loss: 0.008104 ac

[213: 0/5] train loss: 0.012071 accuracy: 1.000000
[213: 0/5] test loss: 0.145980 accuracy: 0.937500
[213: 1/5] train loss: 0.008813 accuracy: 1.000000
[213: 2/5] train loss: 0.019555 accuracy: 1.000000
[213: 3/5] train loss: 0.009211 accuracy: 1.000000
[213: 4/5] train loss: 0.010144 accuracy: 1.000000
[214: 0/5] train loss: 0.010661 accuracy: 1.000000
[214: 0/5] test loss: 0.623957 accuracy: 0.843750
[214: 1/5] train loss: 0.012956 accuracy: 1.000000
[214: 2/5] train loss: 0.009846 accuracy: 1.000000
[214: 3/5] train loss: 0.010844 accuracy: 1.000000
[214: 4/5] train loss: 0.008187 accuracy: 1.000000
[215: 0/5] train loss: 0.009949 accuracy: 1.000000
[215: 0/5] test loss: 0.561077 accuracy: 0.843750
[215: 1/5] train loss: 0.010502 accuracy: 1.000000
[215: 2/5] train loss: 0.007974 accuracy: 1.000000
[215: 3/5] train loss: 0.009967 accuracy: 1.000000
[215: 4/5] train loss: 0.012345 accuracy: 1.000000
[216: 0/5] train loss: 0.008680 accuracy: 1.000000
[216: 0/5] test loss: 0.589900 acc

[239: 2/5] train loss: 0.008003 accuracy: 1.000000
[239: 3/5] train loss: 0.017553 accuracy: 1.000000
[239: 4/5] train loss: 0.009036 accuracy: 1.000000
[240: 0/5] train loss: 0.008961 accuracy: 1.000000
time spent in recent ten epoches is 0.252885 s
[240: 0/5] test loss: 0.592192 accuracy: 0.843750
[240: 1/5] train loss: 0.008938 accuracy: 1.000000
[240: 2/5] train loss: 0.009352 accuracy: 1.000000
[240: 3/5] train loss: 0.008959 accuracy: 1.000000
[240: 4/5] train loss: 0.008974 accuracy: 1.000000
[241: 0/5] train loss: 0.010959 accuracy: 1.000000
[241: 0/5] test loss: 0.328573 accuracy: 0.843750
[241: 1/5] train loss: 0.008514 accuracy: 1.000000
[241: 2/5] train loss: 0.015461 accuracy: 1.000000
[241: 3/5] train loss: 0.008488 accuracy: 1.000000
[241: 4/5] train loss: 0.010236 accuracy: 1.000000
[242: 0/5] train loss: 0.008652 accuracy: 1.000000
[242: 0/5] test loss: 0.288634 accuracy: 0.906250
[242: 1/5] train loss: 0.012493 accuracy: 1.000000
[242: 2/5] train loss: 0.010262 accura

In [9]:
total_correct = 0
total_testset = 0
for i,data in tqdm(enumerate(testdataloader, 0)):
    points, target = data
    target = target[:, 0]
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = classifier.eval()
    pred, _, _ = classifier(points)
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    total_correct += correct.item()
    total_testset += points.size()[0]

print("final accuracy {}".format(total_correct / float(total_testset)))

1it [00:00,  7.33it/s]


final accuracy 0.75
