In [1]:
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import torch.nn as nn
import torch.nn.parallel
from torch.autograd import Variable
import numpy as np
import sys
from tqdm import tqdm 
import json
import torch.nn.functional as F
import argparse
import random
import torch.optim as optim
import time


In [2]:
class PointDataset(data.Dataset):
    def __init__(self,
                 root='../pointData',
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.fns = []
        
        with open(os.path.join(root, '{}.txt'.format(self.split)).replace('\\', '/'), 'r') as f:
            for line in f:
                self.fns.append(line.strip())
        
        with open(os.path.join(root, 'cat.txt').replace('\\', '/'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
                
        print(self.cat) 
        self.classes = list(self.cat.keys())
        
    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        data = [[], [], []]
        
        with open(os.path.join(self.root, fn).replace('\\', '/'), 'r') as f:
            for line in f:
                c, d, e= line.split()
                data[0].append(c)  
                data[1].append(d)  
                data[2].append(e)  
        x = [float(data[0]) for data[0] in data[0]]
        y = [float(data[1]) for data[1] in data[1]]
        z = [float(data[2]) for data[2] in data[2]]
        pts = np.vstack([x,y,z]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]
        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        #all points minus average of columns(i.e. xyz)
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale
        
        '''
        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter
        '''
        
        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        return point_set, cls
    
    def __len__(self):
        return len(self.fns)


In [3]:
class STN3d(nn.Module):
    def __init__(self):
        super(STN3d, self).__init__()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 9)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.array([1,0,0,0,1,0,0,0,1]).astype(np.float32))).view(1,9).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, 3, 3)
        return x

class STNkd(nn.Module):
    def __init__(self, k=64):
        super(STNkd, self).__init__()
        self.conv1 = torch.nn.Conv1d(k, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k*k)
        self.relu = nn.ReLU()

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)

        self.k = k

    def forward(self, x):
        batchsize = x.size()[0]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)

        x = F.relu(self.bn4(self.fc1(x)))
        x = F.relu(self.bn5(self.fc2(x)))
        x = self.fc3(x)

        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
        if x.is_cuda:
            iden = iden.cuda()
        x = x + iden
        x = x.view(-1, self.k, self.k)
        return x
    
class PointNetfeat(nn.Module):
    def __init__(self, global_feat = True, feature_transform = False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2,1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2,1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
        
class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat
    
def feature_transform_regularizer(trans):
    d = trans.size()[1]
    batchsize = trans.size()[0]
    I = torch.eye(d)[None, :, :]
    if trans.is_cuda:
        I = I.cuda()
    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1)) - I, dim=(1,2)))
    return loss

In [4]:
blue = lambda x: '\033[94m' + x + '\033[0m'

manualSeed = random.randint(1, 10000)  # fix seed
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

Random Seed:  2108


<torch._C.Generator at 0x205eae62170>

In [5]:
dataset = PointDataset(
    root='../pointData',
    split='train')

test_dataset = PointDataset(
    root='../pointData',
    split='test',
    data_augmentation=False)

dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    drop_last=True)

testdataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=True,
        drop_last=True)

print(len(dataset), len(test_dataset))
num_classes = len(dataset.classes)
print('classes', num_classes)

{'cap': '0', 'plat': '1', 'vertical': '2'}
{'cap': '0', 'plat': '1', 'vertical': '2'}
180 60
classes 3


In [6]:
try:
    os.makedirs('cls')
except OSError:
    pass

In [7]:
classifier = PointNetCls(k=num_classes, feature_transform=True)
'''
if opt.model != '':
    classifier.load_state_dict(torch.load(opt.model))
'''
optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)

classifier.cuda()


num_batch = len(dataset) / 32



In [8]:
for epoch in range(250):
    scheduler.step()
    time_start=time.time()
    for i, data in enumerate(dataloader, 0):
        points, target = data
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)
        loss = F.nll_loss(pred, target)
        if True:
            loss += feature_transform_regularizer(trans_feat) * 0.001
        loss.backward()
        optimizer.step()
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(target.data).cpu().sum()
        print('[%d: %d/%d] train loss: %f accuracy: %f' % (epoch, i, num_batch, loss.item(), correct.item() / float(32)))

        if i % 10 == 0:
            if epoch%10==0:
                time_record=time.time()
                t=time_record-time_start
                print('time spent in recent ten epoches is %f s'% t)
                time_start=time.time()
            j, data = next(enumerate(testdataloader, 0))
            points, target = data
            target = target[:, 0]
            points = points.transpose(2, 1)
            points, target = points.cuda(), target.cuda()
            classifier = classifier.eval()
            pred, _, _ = classifier(points)
            loss = F.nll_loss(pred, target)
            pred_choice = pred.data.max(1)[1]
            correct = pred_choice.eq(target.data).cpu().sum()
            print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(32)))

    torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % ('cls', epoch))



[0: 0/5] train loss: 1.292954 accuracy: 0.406250
time spent in recent ten epoches is 3.175930 s
[0: 0/5] [94mtest[0m loss: 1.100991 accuracy: 0.281250
[0: 1/5] train loss: 1.373484 accuracy: 0.250000
[0: 2/5] train loss: 1.159541 accuracy: 0.468750
[0: 3/5] train loss: 0.946692 accuracy: 0.781250
[0: 4/5] train loss: 1.070579 accuracy: 0.593750
[1: 0/5] train loss: 0.756839 accuracy: 0.781250
[1: 0/5] [94mtest[0m loss: 1.134177 accuracy: 0.250000
[1: 1/5] train loss: 0.931403 accuracy: 0.656250
[1: 2/5] train loss: 0.809918 accuracy: 0.750000
[1: 3/5] train loss: 0.703610 accuracy: 0.843750
[1: 4/5] train loss: 0.621524 accuracy: 0.843750
[2: 0/5] train loss: 0.714314 accuracy: 0.781250
[2: 0/5] [94mtest[0m loss: 1.082825 accuracy: 0.406250
[2: 1/5] train loss: 0.520958 accuracy: 0.875000
[2: 2/5] train loss: 0.888484 accuracy: 0.718750
[2: 3/5] train loss: 0.593490 accuracy: 0.812500
[2: 4/5] train loss: 0.576049 accuracy: 0.906250
[3: 0/5] train loss: 0.644255 accuracy: 0.87500

[26: 1/5] train loss: 0.160562 accuracy: 0.968750
[26: 2/5] train loss: 0.184779 accuracy: 0.937500
[26: 3/5] train loss: 0.121209 accuracy: 1.000000
[26: 4/5] train loss: 0.096307 accuracy: 1.000000
[27: 0/5] train loss: 0.133752 accuracy: 0.968750
[27: 0/5] [94mtest[0m loss: 0.586001 accuracy: 0.750000
[27: 1/5] train loss: 0.152898 accuracy: 0.937500
[27: 2/5] train loss: 0.251046 accuracy: 0.968750
[27: 3/5] train loss: 0.090774 accuracy: 1.000000
[27: 4/5] train loss: 0.112890 accuracy: 1.000000
[28: 0/5] train loss: 0.097161 accuracy: 1.000000
[28: 0/5] [94mtest[0m loss: 0.627211 accuracy: 0.812500
[28: 1/5] train loss: 0.076650 accuracy: 1.000000
[28: 2/5] train loss: 0.104507 accuracy: 1.000000
[28: 3/5] train loss: 0.206701 accuracy: 0.968750
[28: 4/5] train loss: 0.120772 accuracy: 0.937500
[29: 0/5] train loss: 0.111554 accuracy: 1.000000
[29: 0/5] [94mtest[0m loss: 0.397899 accuracy: 0.875000
[29: 1/5] train loss: 0.075937 accuracy: 1.000000
[29: 2/5] train loss: 0.18

[52: 2/5] train loss: 0.032960 accuracy: 1.000000
[52: 3/5] train loss: 0.085482 accuracy: 0.968750
[52: 4/5] train loss: 0.242943 accuracy: 0.968750
[53: 0/5] train loss: 0.051889 accuracy: 0.968750
[53: 0/5] [94mtest[0m loss: 0.442904 accuracy: 0.875000
[53: 1/5] train loss: 0.029457 accuracy: 1.000000
[53: 2/5] train loss: 0.089077 accuracy: 0.968750
[53: 3/5] train loss: 0.127459 accuracy: 0.937500
[53: 4/5] train loss: 0.046809 accuracy: 1.000000
[54: 0/5] train loss: 0.100710 accuracy: 0.968750
[54: 0/5] [94mtest[0m loss: 0.209502 accuracy: 0.937500
[54: 1/5] train loss: 0.060190 accuracy: 1.000000
[54: 2/5] train loss: 0.028602 accuracy: 1.000000
[54: 3/5] train loss: 0.044758 accuracy: 1.000000
[54: 4/5] train loss: 0.086934 accuracy: 0.968750
[55: 0/5] train loss: 0.048024 accuracy: 1.000000
[55: 0/5] [94mtest[0m loss: 0.561433 accuracy: 0.843750
[55: 1/5] train loss: 0.034016 accuracy: 1.000000
[55: 2/5] train loss: 0.138947 accuracy: 0.937500
[55: 3/5] train loss: 0.03

[78: 4/5] train loss: 0.037943 accuracy: 1.000000
[79: 0/5] train loss: 0.017052 accuracy: 1.000000
[79: 0/5] [94mtest[0m loss: 0.124736 accuracy: 0.968750
[79: 1/5] train loss: 0.014125 accuracy: 1.000000
[79: 2/5] train loss: 0.018276 accuracy: 1.000000
[79: 3/5] train loss: 0.019508 accuracy: 1.000000
[79: 4/5] train loss: 0.014383 accuracy: 1.000000
[80: 0/5] train loss: 0.015893 accuracy: 1.000000
time spent in recent ten epoches is 0.245916 s
[80: 0/5] [94mtest[0m loss: 0.274078 accuracy: 0.937500
[80: 1/5] train loss: 0.016306 accuracy: 1.000000
[80: 2/5] train loss: 0.014031 accuracy: 1.000000
[80: 3/5] train loss: 0.029880 accuracy: 1.000000
[80: 4/5] train loss: 0.015650 accuracy: 1.000000
[81: 0/5] train loss: 0.020410 accuracy: 1.000000
[81: 0/5] [94mtest[0m loss: 0.182522 accuracy: 0.968750
[81: 1/5] train loss: 0.014470 accuracy: 1.000000
[81: 2/5] train loss: 0.025996 accuracy: 1.000000
[81: 3/5] train loss: 0.016718 accuracy: 1.000000
[81: 4/5] train loss: 0.01752

[105: 0/5] train loss: 0.014518 accuracy: 1.000000
[105: 0/5] [94mtest[0m loss: 0.330886 accuracy: 0.937500
[105: 1/5] train loss: 0.021058 accuracy: 1.000000
[105: 2/5] train loss: 0.024661 accuracy: 1.000000
[105: 3/5] train loss: 0.021961 accuracy: 1.000000
[105: 4/5] train loss: 0.013096 accuracy: 1.000000
[106: 0/5] train loss: 0.018645 accuracy: 1.000000
[106: 0/5] [94mtest[0m loss: 0.295794 accuracy: 0.906250
[106: 1/5] train loss: 0.016754 accuracy: 1.000000
[106: 2/5] train loss: 0.017202 accuracy: 1.000000
[106: 3/5] train loss: 0.011159 accuracy: 1.000000
[106: 4/5] train loss: 0.013094 accuracy: 1.000000
[107: 0/5] train loss: 0.016967 accuracy: 1.000000
[107: 0/5] [94mtest[0m loss: 0.413158 accuracy: 0.875000
[107: 1/5] train loss: 0.011457 accuracy: 1.000000
[107: 2/5] train loss: 0.010610 accuracy: 1.000000
[107: 3/5] train loss: 0.016137 accuracy: 1.000000
[107: 4/5] train loss: 0.016755 accuracy: 1.000000
[108: 0/5] train loss: 0.019376 accuracy: 1.000000
[108: 0

[130: 3/5] train loss: 0.011445 accuracy: 1.000000
[130: 4/5] train loss: 0.014551 accuracy: 1.000000
[131: 0/5] train loss: 0.013229 accuracy: 1.000000
[131: 0/5] [94mtest[0m loss: 0.240234 accuracy: 0.937500
[131: 1/5] train loss: 0.020687 accuracy: 1.000000
[131: 2/5] train loss: 0.010412 accuracy: 1.000000
[131: 3/5] train loss: 0.012397 accuracy: 1.000000
[131: 4/5] train loss: 0.012759 accuracy: 1.000000
[132: 0/5] train loss: 0.009534 accuracy: 1.000000
[132: 0/5] [94mtest[0m loss: 0.065227 accuracy: 0.968750
[132: 1/5] train loss: 0.014268 accuracy: 1.000000
[132: 2/5] train loss: 0.022800 accuracy: 1.000000
[132: 3/5] train loss: 0.011123 accuracy: 1.000000
[132: 4/5] train loss: 0.016419 accuracy: 1.000000
[133: 0/5] train loss: 0.009545 accuracy: 1.000000
[133: 0/5] [94mtest[0m loss: 0.524640 accuracy: 0.875000
[133: 1/5] train loss: 0.016295 accuracy: 1.000000
[133: 2/5] train loss: 0.013223 accuracy: 1.000000
[133: 3/5] train loss: 0.011348 accuracy: 1.000000
[133: 4

[156: 2/5] train loss: 0.028581 accuracy: 1.000000
[156: 3/5] train loss: 0.011361 accuracy: 1.000000
[156: 4/5] train loss: 0.019148 accuracy: 1.000000
[157: 0/5] train loss: 0.011673 accuracy: 1.000000
[157: 0/5] [94mtest[0m loss: 0.247852 accuracy: 0.937500
[157: 1/5] train loss: 0.012435 accuracy: 1.000000
[157: 2/5] train loss: 0.011848 accuracy: 1.000000
[157: 3/5] train loss: 0.013060 accuracy: 1.000000
[157: 4/5] train loss: 0.013529 accuracy: 1.000000
[158: 0/5] train loss: 0.010607 accuracy: 1.000000
[158: 0/5] [94mtest[0m loss: 0.060214 accuracy: 0.968750
[158: 1/5] train loss: 0.016200 accuracy: 1.000000
[158: 2/5] train loss: 0.011640 accuracy: 1.000000
[158: 3/5] train loss: 0.010039 accuracy: 1.000000
[158: 4/5] train loss: 0.012374 accuracy: 1.000000
[159: 0/5] train loss: 0.015866 accuracy: 1.000000
[159: 0/5] [94mtest[0m loss: 0.390578 accuracy: 0.906250
[159: 1/5] train loss: 0.012780 accuracy: 1.000000
[159: 2/5] train loss: 0.009492 accuracy: 1.000000
[159: 3

[182: 1/5] train loss: 0.037569 accuracy: 1.000000
[182: 2/5] train loss: 0.009905 accuracy: 1.000000
[182: 3/5] train loss: 0.010576 accuracy: 1.000000
[182: 4/5] train loss: 0.013491 accuracy: 1.000000
[183: 0/5] train loss: 0.014110 accuracy: 1.000000
[183: 0/5] [94mtest[0m loss: 0.651536 accuracy: 0.843750
[183: 1/5] train loss: 0.014446 accuracy: 1.000000
[183: 2/5] train loss: 0.015157 accuracy: 1.000000
[183: 3/5] train loss: 0.011098 accuracy: 1.000000
[183: 4/5] train loss: 0.010116 accuracy: 1.000000
[184: 0/5] train loss: 0.024587 accuracy: 1.000000
[184: 0/5] [94mtest[0m loss: 0.359982 accuracy: 0.906250
[184: 1/5] train loss: 0.011519 accuracy: 1.000000
[184: 2/5] train loss: 0.009945 accuracy: 1.000000
[184: 3/5] train loss: 0.015108 accuracy: 1.000000
[184: 4/5] train loss: 0.011111 accuracy: 1.000000
[185: 0/5] train loss: 0.011102 accuracy: 1.000000
[185: 0/5] [94mtest[0m loss: 0.551683 accuracy: 0.875000
[185: 1/5] train loss: 0.010564 accuracy: 1.000000
[185: 2

[208: 1/5] train loss: 0.009970 accuracy: 1.000000
[208: 2/5] train loss: 0.009567 accuracy: 1.000000
[208: 3/5] train loss: 0.013892 accuracy: 1.000000
[208: 4/5] train loss: 0.009888 accuracy: 1.000000
[209: 0/5] train loss: 0.009317 accuracy: 1.000000
[209: 0/5] [94mtest[0m loss: 0.428813 accuracy: 0.906250
[209: 1/5] train loss: 0.009725 accuracy: 1.000000
[209: 2/5] train loss: 0.011305 accuracy: 1.000000
[209: 3/5] train loss: 0.092640 accuracy: 0.968750
[209: 4/5] train loss: 0.009256 accuracy: 1.000000
[210: 0/5] train loss: 0.013445 accuracy: 1.000000
time spent in recent ten epoches is 0.246912 s
[210: 0/5] [94mtest[0m loss: 0.435643 accuracy: 0.906250
[210: 1/5] train loss: 0.012339 accuracy: 1.000000
[210: 2/5] train loss: 0.010769 accuracy: 1.000000
[210: 3/5] train loss: 0.010402 accuracy: 1.000000
[210: 4/5] train loss: 0.010296 accuracy: 1.000000
[211: 0/5] train loss: 0.010812 accuracy: 1.000000
[211: 0/5] [94mtest[0m loss: 0.590647 accuracy: 0.875000
[211: 1/5] 

[234: 0/5] train loss: 0.009929 accuracy: 1.000000
[234: 0/5] [94mtest[0m loss: 0.297506 accuracy: 0.906250
[234: 1/5] train loss: 0.010872 accuracy: 1.000000
[234: 2/5] train loss: 0.013886 accuracy: 1.000000
[234: 3/5] train loss: 0.009820 accuracy: 1.000000
[234: 4/5] train loss: 0.011425 accuracy: 1.000000
[235: 0/5] train loss: 0.014513 accuracy: 1.000000
[235: 0/5] [94mtest[0m loss: 0.537373 accuracy: 0.875000
[235: 1/5] train loss: 0.014504 accuracy: 1.000000
[235: 2/5] train loss: 0.013368 accuracy: 1.000000
[235: 3/5] train loss: 0.010732 accuracy: 1.000000
[235: 4/5] train loss: 0.014711 accuracy: 1.000000
[236: 0/5] train loss: 0.009927 accuracy: 1.000000
[236: 0/5] [94mtest[0m loss: 0.381672 accuracy: 0.906250
[236: 1/5] train loss: 0.010759 accuracy: 1.000000
[236: 2/5] train loss: 0.009573 accuracy: 1.000000
[236: 3/5] train loss: 0.011929 accuracy: 1.000000
[236: 4/5] train loss: 0.012248 accuracy: 1.000000
[237: 0/5] train loss: 0.014263 accuracy: 1.000000
[237: 0

In [9]:
total_correct = 0
total_testset = 0
for i,data in tqdm(enumerate(testdataloader, 0)):
    points, target = data
    target = target[:, 0]
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    classifier = classifier.eval()
    pred, _, _ = classifier(points)
    pred_choice = pred.data.max(1)[1]
    correct = pred_choice.eq(target.data).cpu().sum()
    total_correct += correct.item()
    total_testset += points.size()[0]

print("final accuracy {}".format(total_correct / float(total_testset)))

1it [00:00,  7.07it/s]

final accuracy 0.875





In [10]:
total_correct = 0
total_testset = 0
classifier.load_state_dict(torch.load('./cls/cls_model_80.pth'))
dataiter = iter(testdataloader)
points, target = dataiter.next()
target = target[:, 0]
print(target)
points = points.transpose(2, 1)
points, target = points.cuda(), target.cuda()
classifier = classifier.eval()
pred, _, _ = classifier(points)
print(pred)
pred_choice = pred.data.max(1)[1]
print(pred_choice)
correct = pred_choice.eq(target.data).cpu().sum()
total_correct += correct.item()
total_testset += points.size()[0]

print(total_correct)
print(total_testset)
print("final accuracy {}".format(total_correct / float(total_testset)))

tensor([1, 0, 1, 2, 0, 2, 0, 2, 2, 1, 1, 0, 1, 0, 2, 2, 0, 2, 0, 2, 1, 1, 2, 2,
        0, 2, 2, 0, 1, 0, 1, 1])
tensor([[-5.5859e+00, -4.8733e-03, -6.8026e+00],
        [-2.0904e-01, -2.8511e+00, -2.0337e+00],
        [-6.4583e+00, -2.0617e-03, -7.6166e+00],
        [-5.1048e+00, -5.8585e+00, -8.9635e-03],
        [-5.5449e+00, -6.1972e+00, -5.9602e-03],
        [-6.6444e+00, -7.4678e+00, -1.8742e-03],
        [-9.5414e-03, -5.3376e+00, -5.3627e+00],
        [-5.1128e+00, -5.9481e+00, -8.6672e-03],
        [-5.7402e-03, -5.8669e+00, -5.8458e+00],
        [-6.2760e+00, -2.4968e-03, -7.3975e+00],
        [-1.3898e+00, -2.3962e+00, -4.1579e-01],
        [-6.7231e-01, -2.3322e+00, -9.3550e-01],
        [-4.5524e+00, -1.3319e-02, -5.9185e+00],
        [-1.1808e-02, -5.1695e+00, -5.1075e+00],
        [-3.9684e+00, -4.6335e+00, -2.9042e-02],
        [-7.9014e-01, -2.8133e+00, -7.2111e-01],
        [-5.1793e-03, -6.0702e+00, -5.8586e+00],
        [-5.5677e+00, -6.1820e+00, -5.9031e-03],
     