<a href="https://colab.research.google.com/github/atp1988/PointCloud/blob/main/point_cloud_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show the GPU assigned to this runtime (driver/CUDA version, memory usage).
!nvidia-smi

Fri Jul  2 11:43:58 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.27       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   40C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Imports

In [None]:
import numpy as np
import math
import random
import os
import torch
import scipy.spatial.distance

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, utils

import plotly.graph_objects as go
import plotly.express as px

In [None]:
!pip install torchmetrics



In [None]:
import torchmetrics

In [None]:
!pip install path.py

Collecting path.py
  Downloading https://files.pythonhosted.org/packages/8f/04/130b7a538c25693c85c4dee7e25d126ebf5511b1eb7320e64906687b159e/path.py-12.5.0-py3-none-any.whl
Collecting path
  Downloading https://files.pythonhosted.org/packages/a7/d6/68e26fb6204e280c770deffcd3062a380b769436740196a26fbcc7f6e2d4/path-16.0.0-py3-none-any.whl
Installing collected packages: path, path.py
Successfully installed path-16.0.0 path.py-12.5.0


In [None]:
from path import Path

# Download Dataset

In [None]:
!wget http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip

--2021-07-02 11:47:30--  http://3dvision.princeton.edu/projects/2014/3DShapeNets/ModelNet10.zip
Resolving 3dvision.princeton.edu (3dvision.princeton.edu)... 128.112.136.61
Connecting to 3dvision.princeton.edu (3dvision.princeton.edu)|128.112.136.61|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 473402300 (451M) [application/zip]
Saving to: ‘ModelNet10.zip’


2021-07-02 11:47:41 (40.2 MB/s) - ‘ModelNet10.zip’ saved [473402300/473402300]



In [None]:
!unzip -q ModelNet10.zip

# Load Dataset

In [None]:
# Root of the extracted dataset; each class is a sub-directory with train/ and test/ splits.
path = Path("ModelNet10")
path

Path('ModelNet10')

In [None]:
# Path to one class directory ('bed').
path/'bed'

Path('ModelNet10/bed')

## Read OFF files

In [None]:
def read_off(file):
    """Parse a triangle mesh in OFF format from an open text file.

    Expected layout: an ``OFF`` header line, a ``n_verts n_faces n_edges`` line,
    then ``n_verts`` lines of ``x y z`` and ``n_faces`` lines of
    ``count i j k ...``. A header with the counts fused onto it
    (e.g. ``OFF490 518 0``) is also accepted.

    Args:
        file: readable text file object positioned at the start of the OFF data.

    Returns:
        (verts, faces): ``verts`` is a list of ``[x, y, z]`` float lists; ``faces`` is
        a list of vertex-index lists (the leading per-face count is dropped).

    Raises:
        ValueError: if the header does not start with ``OFF`` or a line is malformed.
    """
    header = file.readline().strip()
    if not header.startswith('OFF'):
        raise ValueError('Not a valid OFF header')
    # Counts normally sit on the next line, but may follow "OFF" on the same line.
    counts_line = header[3:].strip() or file.readline().strip()
    n_verts, n_faces, _ = (int(s) for s in counts_line.split())
    # split() (no argument) tolerates repeated / trailing whitespace.
    verts = [[float(s) for s in file.readline().split()] for _ in range(n_verts)]
    faces = [[int(s) for s in file.readline().split()][1:] for _ in range(n_faces)]
    return verts, faces

In [None]:
with open(path/"chair/train/chair_0001.off", 'r') as f:
    verts, faces = read_off(f)

# Face vertex indices (i, j, k) and vertex coordinates (x, y, z), column-wise for go.Mesh3d.
i, j, k = np.array(faces).T
x, y, z = np.array(verts).T
len(x)

2382

In [None]:
# x-coordinates of the chair mesh's vertices (set in the cell above).
x

array([-9.6995, -9.6965, -9.6995, ...,  9.3055,  9.3055, -9.1335])

## Visualize

In [None]:
def visualize_rotate(data):
    """Wrap Plotly traces in a figure whose 'Play' button orbits the camera.

    The camera eye starts at (1.25, 1.25, 0.8) and is rotated about the z axis by
    -angle for angle in 0, 0.1, ..., 10.2 rad, giving one animation frame per step.

    Args:
        data: list of Plotly traces (e.g. ``go.Mesh3d`` or ``go.Scatter3d``).

    Returns:
        The ``go.Figure``; the caller decides when to ``.show()`` it.
    """
    eye_x, eye_y, eye_z = 1.25, 1.25, 0.8

    def rotate_z(ex, ey, ez, angle):
        # Rotate (ex, ey) about the origin via complex multiplication; ez is unchanged.
        rotated = np.exp(1j * angle) * (ex + 1j * ey)
        return np.real(rotated), np.imag(rotated), ez

    def camera_frame(angle):
        cam_x, cam_y, cam_z = rotate_z(eye_x, eye_y, eye_z, -angle)
        return dict(layout=dict(scene=dict(camera=dict(eye=dict(x=cam_x, y=cam_y, z=cam_z)))))

    frames = [camera_frame(angle) for angle in np.arange(0, 10.26, 0.1)]

    play_button = dict(label='Play',
                       method='animate',
                       args=[None, dict(frame=dict(duration=50, redraw=True),
                                        transition=dict(duration=0),
                                        fromcurrent=True,
                                        mode='immediate')])
    menu = dict(type='buttons',
                showactive=False,
                y=1,
                x=0.8,
                xanchor='left',
                yanchor='bottom',
                pad=dict(t=45, r=10),
                buttons=[play_button])

    return go.Figure(data=data,
                     layout=go.Layout(updatemenus=[menu]),
                     frames=frames)

In [None]:
# Animated, semi-transparent render of the chair mesh.
visualize_rotate([
    go.Mesh3d(x=x, y=y, z=z, i=i, j=j, k=k, color='lightpink', opacity=0.50)
]).show()

In [None]:
def pcshow(xs, ys, zs):
    """Display the points (xs, ys, zs) as an animated 3-D scatter plot.

    Markers are size 2 with a 2 px 'DarkSlateGrey' outline; the figure comes from
    `visualize_rotate`, so it has the same orbiting-camera 'Play' button.
    """
    scatter = go.Scatter3d(x=xs, y=ys, z=zs, mode='markers')
    figure = visualize_rotate([scatter])
    marker_style = dict(size=2, line=dict(width=2, color='DarkSlateGrey'))
    figure.update_traces(marker=marker_style, selector=dict(mode='markers'))
    figure.show()


In [None]:
# The raw mesh vertices shown as a point cloud (note how unevenly they cover the surface).
pcshow(x,y,z)

## Sample points

In [None]:
class PointSampler(object):
    """Sample `output_size` points uniformly from the surface of a triangle mesh.

    Faces are drawn (with replacement, via Python's `random`) with probability
    proportional to their area, then one point is drawn uniformly inside each
    drawn face using barycentric coordinates.

    Args:
        output_size: number of points to return.
    """
    def __init__(self, output_size):
        assert isinstance(output_size, int)
        self.output_size = output_size

    def triangle_area(self, pt1, pt2, pt3):
        """Area of the 3-D triangle (pt1, pt2, pt3).

        Half the norm of the edge cross product; unlike Heron's formula this does
        not lose precision through cancellation on long, thin triangles.
        """
        pt1 = np.asarray(pt1, dtype=float)
        edge_a = np.asarray(pt2, dtype=float) - pt1
        edge_b = np.asarray(pt3, dtype=float) - pt1
        return 0.5 * np.linalg.norm(np.cross(edge_a, edge_b))

    def sample_point(self, pt1, pt2, pt3):
        """Draw one point uniformly inside the triangle (pt1, pt2, pt3)."""
        # barycentric coordinates on a triangle
        # https://mathworld.wolfram.com/BarycentricCoordinates.html
        # Sorting two uniforms splits [0, 1] into three pieces (s, t - s, 1 - t),
        # which are uniformly distributed barycentric weights.
        s, t = sorted([random.random(), random.random()])
        f = lambda i: s * pt1[i] + (t-s)*pt2[i] + (1-t)*pt3[i]
        return (f(0), f(1), f(2))

    def __call__(self, mesh):
        """Sample points from `mesh`.

        Args:
            mesh: ``(verts, faces)`` as returned by `read_off`; only the first
                three indices of each face are used.

        Returns:
            float array of shape ``(output_size, 3)``.

        Raises:
            ValueError: if no face has positive area (nothing to sample from).
        """
        verts, faces = mesh
        verts = np.array(verts)
        areas = np.array([self.triangle_area(verts[face[0]], verts[face[1]], verts[face[2]])
                          for face in faces], dtype=float)
        if not np.any(areas > 0):
            raise ValueError('mesh has no faces with positive area to sample from')

        sampled_faces = random.choices(faces, weights=areas, k=self.output_size)

        sampled_points = np.zeros((self.output_size, 3))
        for row, face in enumerate(sampled_faces):
            sampled_points[row] = self.sample_point(verts[face[0]],
                                                    verts[face[1]],
                                                    verts[face[2]])
        return sampled_points


In [None]:
pointcloud = PointSampler(1024)((verts, faces))  # area-weighted surface samples; any count works, e.g. 3000

In [None]:
# Sampled points cover the surface evenly, unlike the raw vertices.
pcshow(*pointcloud.T)

## Normalize

In [None]:
class Normalize(object):
    """Center a point cloud on its centroid and scale it into the unit sphere.

    After the centroid is subtracted, coordinates are divided by the largest
    point-to-centroid distance, so the farthest point lies at distance 1.
    A cloud whose points all coincide is returned centered (all zeros) rather
    than divided by zero.
    """
    def __call__(self, pointcloud):
        assert len(pointcloud.shape)==2

        norm_pointcloud = pointcloud - np.mean(pointcloud, axis=0)
        scale = np.max(np.linalg.norm(norm_pointcloud, axis=1))
        # Zero extent would turn every coordinate into NaN.
        if scale > 0:
            norm_pointcloud = norm_pointcloud / scale

        return norm_pointcloud

In [None]:
# Centered at the origin and scaled into the unit sphere.
norm_pointcloud = Normalize()(pointcloud)

In [None]:
# Same shape as before; only the position and scale of the axes change.
pcshow(*norm_pointcloud.T)

## Augmentations

In [None]:
class RandRotation_z(object):
    """Rotate a point cloud about the z axis by a uniformly random angle in [0, 2*pi).

    The angle is drawn from Python's `random` module; z coordinates are unchanged.
    """
    def __call__(self, pointcloud):
        assert len(pointcloud.shape) == 2

        angle = random.random() * 2. * math.pi
        cos_a = math.cos(angle)
        sin_a = math.sin(angle)
        rotation = np.array([[cos_a, -sin_a, 0],
                             [sin_a,  cos_a, 0],
                             [0,      0,     1]])

        # (R . P^T)^T applies R to every row (point) of the cloud.
        return rotation.dot(pointcloud.T).T

class RandomNoise(object):
    """Add i.i.d. Gaussian jitter (mean 0, std `sigma`) to every coordinate.

    Noise is drawn from NumPy's global RNG (`np.random.normal`).

    Args:
        sigma: standard deviation of the jitter; default 0.02.
    """
    def __init__(self, sigma=0.02):
        self.sigma = sigma

    def __call__(self, pointcloud):
        assert len(pointcloud.shape)==2

        noise = np.random.normal(0, self.sigma, (pointcloud.shape))

        noisy_pointcloud = pointcloud + noise
        return  noisy_pointcloud

In [None]:
# Apply the two training-time augmentations to the normalized chair.
rot_pointcloud = RandRotation_z()(norm_pointcloud)
noisy_rot_pointcloud = RandomNoise()(rot_pointcloud)

In [None]:
# Rotated about z and jittered with Gaussian noise.
pcshow(*noisy_rot_pointcloud.T)

## ToTensor

In [None]:
class ToTensor(object):
    """Wrap a 2-D NumPy point cloud as a torch tensor (same dtype, shared memory)."""
    def __call__(self, pointcloud):
        assert len(pointcloud.shape) == 2
        tensor = torch.from_numpy(pointcloud)
        return tensor

## Transform

In [None]:
def default_transforms():
    """Deterministic preprocessing: sample 1024 surface points, normalize, to tensor.

    No augmentation; this is what validation/test datasets use.
    """
    steps = [PointSampler(1024), Normalize(), ToTensor()]
    return transforms.Compose(steps)

In [None]:
# Training pipeline: the default preprocessing plus random z-rotation and Gaussian
# jitter, applied after normalization and before tensor conversion.
train_transforms = transforms.Compose([
    PointSampler(1024),
    Normalize(),
    RandRotation_z(),
    RandomNoise(),
    ToTensor(),
])

## Dataset

In [None]:
class PointCloudData(Dataset):
    """ModelNet-style dataset with one class per sub-directory of `root_dir`.

    Reads ``root_dir/<category>/<folder>/*.off``. Each item is a dict with the
    transformed mesh under ``'pointcloud'`` and the integer class index (the
    category's position in the sorted directory listing) under ``'category'``.
    Files are listed in sorted order so item indices are reproducible.

    Args:
        root_dir: dataset root; must support ``/`` joining (a ``Path``).
        valid: if True, `transform` is ignored and ``default_transforms()`` is used.
        folder: split sub-directory to read, e.g. ``"train"`` or ``"test"``.
        transform: callable applied to the ``(verts, faces)`` mesh. If falsy, the raw
            ``(verts, faces)`` tuple is returned as ``'pointcloud'``. NOTE: the default
            Compose is built once, at class-definition time, and shared by instances.
    """
    def __init__(self, root_dir, valid=False, folder="train", transform=default_transforms()):
        self.root_dir = root_dir
        folders = [d for d in sorted(os.listdir(root_dir)) if os.path.isdir(root_dir/d)]
        self.classes = {category: idx for idx, category in enumerate(folders)}
        self.transforms = transform if not valid else default_transforms()
        self.valid = valid
        self.files = []
        for category in self.classes:
            new_dir = root_dir/Path(category)/folder
            # Sorted so that dataset indices don't depend on filesystem listing order.
            for file in sorted(os.listdir(new_dir)):
                if file.endswith('.off'):
                    self.files.append({'pcd_path': new_dir/file,
                                       'category': category})

    def __len__(self):
        return len(self.files)

    def __preproc__(self, file):
        """Parse the OFF `file` and apply the dataset transform (if any)."""
        verts, faces = read_off(file)
        if not self.transforms:
            # No transform configured: hand back the raw mesh.
            return verts, faces
        return self.transforms((verts, faces))

    def __getitem__(self, idx):
        pcd_path = self.files[idx]['pcd_path']
        category = self.files[idx]['category']
        with open(pcd_path, 'r') as f:
            pointcloud = self.__preproc__(f)
        return {'pointcloud': pointcloud,
                'category': self.classes[category]}

In [None]:
# Augmented training split; the test split uses deterministic preprocessing
# (valid=True already selects default_transforms(), so the explicit transform is redundant).
train_ds = PointCloudData(path, transform=train_transforms)
test_ds = PointCloudData(path, valid=True, folder='test',transform=default_transforms())

In [None]:
# Dataset sizes and one sample's shape (indexing train_ds runs the random training transforms).
print('Train dataset size: ', len(train_ds))
print('Valid dataset size: ', len(test_ds))
print('Number of classes: ', len(train_ds.classes))
print('Sample pointcloud shape: ', train_ds[0]['pointcloud'].size())

Train dataset size:  3991
Valid dataset size:  908
Number of classes:  10
Sample pointcloud shape:  torch.Size([1024, 3])


In [None]:
# One transformed training item: {'pointcloud': float64 tensor [1024, 3], 'category': int}.
data = train_ds[2000]
data

{'category': 5, 'pointcloud': tensor([[ 0.0093, -0.0872,  0.5838],
         [-0.0220,  0.1198,  0.3731],
         [ 0.6656, -0.0288, -0.3183],
         ...,
         [ 0.2325,  0.0463,  0.2639],
         [ 0.0308, -0.1053, -0.4908],
         [-0.6780,  0.1861,  0.6971]], dtype=torch.float64)}

In [None]:
# Plot the sample above (tensor columns unpack into x, y, z).
pcshow(*data['pointcloud'].T)

# Dataloader

In [None]:
# Only the training data is shuffled; evaluation order does not matter.
train_loader = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=64)

In [None]:
# A collated batch stacks point clouds into [batch, points, xyz].
batch = next(iter(train_loader))
batch['pointcloud'].shape

torch.Size([32, 1024, 3])

# Model

In [None]:
!pip install "git+git://github.com/erikwijmans/Pointnet2_PyTorch.git#egg=pointnet2_ops&subdirectory=pointnet2_ops_lib"



In [None]:
import pointnet2_ops
from pointnet2_ops.pointnet2_utils import furthest_point_sample

## Sampling & Grouping

In [None]:
def square_distance(src, dst):
    """Pairwise squared Euclidean distances between two batched point sets.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, so the full
    distance matrix comes from one batched matmul plus two broadcast adds.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross_term = torch.matmul(src, dst.permute(0, 2, 1))   # [B, N, M]
    dist = -2 * cross_term
    dist += torch.sum(src ** 2, -1).view(batch, n_src, 1)   # ||src_n||^2 down rows
    dist += torch.sum(dst ** 2, -1).view(batch, 1, n_dst)   # ||dst_m||^2 across columns
    return dist

In [None]:
def knn_point(nsample, xyz, new_xyz):
    """Indices of the `nsample` nearest points in `xyz` for every query in `new_xyz`.

    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
            (neighbours within a group are not ordered by distance)
    """
    dists = square_distance(new_xyz, xyz)   # [B, S, N]
    nearest = torch.topk(dists, nsample, dim=-1, largest=False, sorted=False).indices
    return nearest

In [None]:
def index_points(points, idx):
    """Gather points per batch element: out[b, ...] = points[b, idx[b, ...], :].

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S] (or [B, S, K])
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    batch = points.shape[0]
    # Batch indices shaped [B, 1, ..., 1], broadcast to idx's shape for advanced indexing.
    batch_shape = [batch] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(batch, dtype=torch.long, device=points.device)
    batch_indices = batch_indices.view(batch_shape).expand_as(idx)
    return points[batch_indices, idx, :]

In [None]:
def sample_and_group(npoint, radius, nsample, xyz, points):
    """Farthest-point-sample `npoint` centroids and group each with its k nearest neighbours.

    Neighbourhoods are k-nearest-neighbour groups (`knn_point`), not radius balls,
    so `radius` is currently unused; it is kept for call compatibility.

    Input:
        npoint: number of centroids to sample (S)
        radius: unused (see above)
        nsample: neighbours per centroid (K)
        xyz: input points position data, [B, N, 3]
        points: input points data, [B, N, D]
    Return:
        new_xyz: sampled centroid positions, [B, S, 3]
        new_points: grouped features, [B, S, K, 2*D]; for every neighbour, its
            feature minus the centroid's feature, concatenated with the centroid's feature
    """
    B = xyz.shape[0]
    S = npoint
    # presumably required by the pointnet2_ops CUDA kernel — kept from the original
    xyz = xyz.contiguous()

    fps_idx = furthest_point_sample(xyz, npoint).long()  # [B, S]
    new_xyz = index_points(xyz, fps_idx)                  # [B, S, 3]
    new_points = index_points(points, fps_idx)            # [B, S, D]

    idx = knn_point(nsample, xyz, new_xyz)                # [B, S, K]
    grouped_points = index_points(points, idx)            # [B, S, K, D]
    centroid_points = new_points.view(B, S, 1, -1)        # [B, S, 1, D]
    grouped_points_norm = grouped_points - centroid_points
    new_points = torch.cat([grouped_points_norm, centroid_points.repeat(1, 1, nsample, 1)], dim=-1)
    return new_xyz, new_points

## MiniPointNet

In [None]:
class MiniPointNet(nn.Module):
  """Shared two-layer point-wise MLP followed by max-pooling over each local group.

  Input  [B, N, S, C_in]: N groups of S neighbours with C_in features each.
  Output [B, N, C_out]:   one pooled feature vector per group.
  """
  def __init__(self, in_channels, out_channels):
    super(MiniPointNet, self).__init__()
    self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm1d(out_channels)
    # conv2 consumes conv1's output, so its input width is out_channels.
    self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False)
    self.bn2 = nn.BatchNorm1d(out_channels)

  def forward(self, x):  # e.g. [B, N, S, C] = [32, 512, 32, 128]
    b, n, s, c = x.size()
    x = x.permute(0, 1, 3, 2)              # [B, N, C_in, S]
    x = x.reshape(-1, c, s)                # [B*N, C_in, S]: each group is one Conv1d sample
    x = F.relu(self.bn1(self.conv1(x)))    # LBR 1 -> [B*N, C_out, S]
    x = F.relu(self.bn2(self.conv2(x)))    # LBR 2 -> [B*N, C_out, S]
    x = F.adaptive_max_pool1d(x, 1).squeeze(-1)  # max over the group -> [B*N, C_out]
    x = x.reshape(b, n, -1)                # [B, N, C_out]
    return x

In [None]:
# Shape check: [B, N, S, C] -> [B, N, C] (max-pooled over the S neighbours).
MiniPointNet(128, 128)(torch.rand(32, 512, 32, 128)).shape

torch.Size([32, 512, 128])

## Transformer

In [None]:
class Attention(nn.Module):
  """Offset-attention block over a [B, C, N] feature map (first draft).

  NOTE: the next cell defines another `Attention`, which replaces this one.

  Q and K project to C//4 channels, V keeps C. The attention map is softmaxed over
  keys, re-normalized over queries, applied to V, and the offset (x - attended)
  goes through Conv-BN-ReLU before being added back to x.
  """
  def __init__(self, channels):
    super(Attention, self).__init__()
    self.q_conv = nn.Conv1d(channels, channels//4, 1, bias=False)
    self.k_conv = nn.Conv1d(channels, channels//4, 1, bias=False)
    self.v_conv = nn.Conv1d(channels, channels, 1, bias=False)
    self.softmax = nn.Softmax(dim=-1)
    self.conv = nn.Conv1d(channels, channels, 1)
    self.bn = nn.BatchNorm1d(channels)

  def forward(self, x):
    x_q = self.q_conv(x).permute(0, 2, 1)   # [B, N, C//4]
    x_k = self.k_conv(x)                    # [B, C//4, N]
    x_v = self.v_conv(x)                    # [B, C, N]
    w = self.softmax(torch.bmm(x_q, x_k))   # [B, N, N]
    # Out-of-place: softmax's backward reuses its output, so it must not be modified in place.
    w = w / (1e-6 + w.sum(dim=1, keepdim=True))
    x_r = torch.bmm(x_v, w)                 # [B, C, N]
    x_r = F.relu(self.bn(self.conv(x - x_r)))
    # Out-of-place residual: never overwrite the caller's tensor.
    return x + x_r

In [None]:
class Attention(nn.Module):
    """Offset-attention (PCT) over a [B, C, N] feature map; the output has the same shape.

    The query and key projections share one weight (C -> C//4, no bias); values keep
    C channels. The attention map is softmaxed over its last axis, then each column
    is re-normalized by its sum over queries. The offset between the input and the
    attended features passes through Conv1d -> BatchNorm -> ReLU and is added back to
    the input as a residual.
    """

    def __init__(self, channels):
        super(Attention, self).__init__()
        reduced = channels // 4
        self.q_conv = nn.Conv1d(channels, reduced, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, reduced, 1, bias=False)
        # Tie the query projection to the key projection (both biases are None).
        self.q_conv.weight = self.k_conv.weight
        self.q_conv.bias = self.k_conv.bias

        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        queries = self.q_conv(x).permute(0, 2, 1)   # [B, N, C//4]
        keys = self.k_conv(x)                       # [B, C//4, N]
        values = self.v_conv(x)                     # [B, C, N]

        attention = self.softmax(torch.bmm(queries, keys))  # [B, N, N]
        column_sums = attention.sum(dim=1, keepdim=True)
        attention = attention / (1e-9 + column_sums)

        attended = torch.bmm(values, attention)     # [B, C, N]
        offset = self.act(self.after_norm(self.trans_conv(x - attended)))
        return x + offset

In [None]:
class Transformer(nn.Module):
    """Two Conv-BN-ReLU layers followed by four stacked `Attention` blocks.

    Input  [B, channels, N].
    Output [B, 4 * channels, N]: the four attention outputs concatenated along channels.
    """
    def __init__(self, channels=256):
        super(Transformer, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(channels)
        self.bn2 = nn.BatchNorm1d(channels)

        self.sa1 = Attention(channels)
        self.sa2 = Attention(channels)
        self.sa3 = Attention(channels)
        self.sa4 = Attention(channels)

    def forward(self, x):
        if x.dim() != 3:
            raise ValueError('expected a [B, C, N] tensor, got shape %s' % (tuple(x.shape),))
        # [B, C, N] -> [B, C, N]
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        # Each attention block refines the previous one's output.
        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        return torch.cat((x1, x2, x3, x4), dim=1)  # [B, 4*C, N]

## PCT

In [None]:
class PCT(nn.Module):
  """Point Cloud Transformer classifier.

  Input  [B, 3, N] xyz coordinates (N = 1024 in this notebook).
  Output [B, num_classes] unnormalized class scores (logits).

  Pipeline: per-point embedding (two Conv-BN-ReLU, 3 -> 64 channels); two
  sample-and-group + MiniPointNet stages (512 then 320 centroids, 32 neighbours,
  128 then 256 channels); the 4-block Transformer (256 -> 1024 channels);
  Conv-BN-ReLU; global max-pool; MLP classifier.

  Args:
    num_classes: size of the output layer (10 for ModelNet10).
  """
  def __init__(self, num_classes=10):
    super(PCT, self).__init__()
    # Input embedding (Green): per-point LBR x2, 3 -> 64 channels
    self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
    self.bn1 = nn.BatchNorm1d(64)
    self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
    self.bn2 = nn.BatchNorm1d(64)
    # Input embedding (Blue): each sample_and_group doubles the feature width
    # (neighbour offset ++ centroid feature): 64 -> 128, then 128 -> 256.
    self.mini1 = MiniPointNet(128, 128)
    self.mini2 = MiniPointNet(256, 256)
    # Transformer: 256 -> 4 * 256 = 1024 channels
    self.transformer = Transformer(256)
    # LBR
    self.conv3 = nn.Conv1d(1024, 1024, kernel_size=1, bias=False)
    self.bn3 = nn.BatchNorm1d(1024)
    # Classifier
    self.cls = nn.Sequential(nn.Linear(1024, 512),
                             nn.LayerNorm(512),
                             nn.LeakyReLU(negative_slope=0.2),
                             nn.Linear(512, 256),
                             nn.LayerNorm(256),
                             nn.LeakyReLU(negative_slope=0.2),
                             nn.Linear(256, num_classes))

  def forward(self, x):
    xyz = x.permute(0, 2, 1)                  # [B, N, 3] coordinates for sampling/grouping
    # Input embedding (Green)
    x = F.relu(self.bn1(self.conv1(x)))       # LBR 1 -> [B, 64, N]
    x = F.relu(self.bn2(self.conv2(x)))       # LBR 2 -> [B, 64, N]
    # Input embedding (SG); `radius` is passed but unused by sample_and_group's kNN grouping
    x = x.permute(0, 2, 1)                    # [B, N, 64]
    new_xyz, new_points = sample_and_group(npoint=512, radius=0.15, nsample=32, xyz=xyz, points=x)
    new_points = self.mini1(new_points)       # [B, 512, 128]
    new_xyz, new_points = sample_and_group(npoint=320, radius=0.2, nsample=32, xyz=new_xyz, points=new_points)
    new_points = self.mini2(new_points)       # [B, 320, 256]
    new_points = new_points.permute(0, 2, 1)  # [B, 256, 320]
    feature = self.transformer(new_points)    # [B, 1024, 320]
    # TODO(review): an unimplemented "Concat" step was planned here — presumably
    # concatenating pre-transformer features before the LBR; confirm against the PCT paper.
    # LBR
    feature = F.relu(self.bn3(self.conv3(feature)))
    # Maxpool over points
    feature = F.adaptive_max_pool1d(feature, 1).squeeze(-1)  # [B, 1024]
    # Classifier
    return self.cls(feature)                  # [B, num_classes]

In [None]:
# Instantiate PCT on the GPU (the smoke test below feeds it CUDA tensors).
model = PCT().cuda()

In [None]:
# Smoke-test the forward pass on random input. eval() + no_grad() keep this dummy batch
# from updating the BatchNorm running statistics and from building an autograd graph;
# `train` switches the model back with model.train().
model.eval()
with torch.no_grad():
    feature = model(torch.rand(2, 3, 1024).cuda())
feature.shape

# Config

In [None]:
# Device the training loop moves each batch to.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Move the model to `device` (already on the GPU from .cuda()); the repr shows the architecture.
model.to(device)

PCT(
  (conv1): Conv1d(3, 64, kernel_size=(1,), stride=(1,), bias=False)
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv1d(64, 64, kernel_size=(1,), stride=(1,), bias=False)
  (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mini1): MiniPointNet(
    (conv1): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(128, 128, kernel_size=(1,), stride=(1,), bias=False)
    (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (mini2): MiniPointNet(
    (conv1): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
    (bn1): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (conv2): Conv1d(256, 256, kernel_size=(1,), stride=(1,), bias=False)
    (bn2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affin

In [None]:
# Adam over all model parameters; `train` reads this global.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Classification criterion on raw logits; `train` reads this global as `Loss`.
Loss = nn.CrossEntropyLoss()

In [None]:
"""#AverageMeter"""
class AverageMeter(object):
    """Track the latest value and the running, count-weighted mean of a metric."""
    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, observation count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as the latest value, observed `n` times, and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

## Training loop

In [None]:
def train(model, train_loader, test_loader, epochs=15, max_batches=None):
    """Train `model` for `epochs` epochs, evaluating on `test_loader` after each.

    Uses the notebook-level `optimizer`, `Loss` (criterion) and `device`. After each
    epoch it prints the sample-weighted mean training loss, training accuracy and
    test accuracy.

    Args:
        model: network mapping [B, 3, N] point clouds to [B, num_classes] logits.
        train_loader, test_loader: iterables of dicts with 'pointcloud' [B, N, 3]
            and 'category' [B] (integer labels).
        epochs: number of passes over `train_loader`.
        max_batches: if set, each train/eval pass stops after this many batches
            (for quick smoke runs); None processes the full loaders.
    """
    for epoch in range(epochs):
        model.train()
        loss_meter = AverageMeter()
        train_correct = train_total = 0

        for i, data in enumerate(train_loader):
            if max_batches is not None and i >= max_batches:
                break
            inputs = data['pointcloud'].to(device).float()
            labels = data['category'].to(device)
            optimizer.zero_grad()
            outputs = model(inputs.transpose(1, 2))   # [B, N, 3] -> [B, 3, N]
            loss = Loss(outputs, labels)
            loss.backward()
            optimizer.step()
            # .item() detaches: accumulating the tensor would keep every batch's graph alive.
            loss_meter.update(loss.item(), labels.size(0))
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()
            train_total += labels.size(0)

        model.eval()
        correct = total = 0
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                if max_batches is not None and i >= max_batches:
                    break
                inputs = data['pointcloud'].to(device).float()
                labels = data['category'].to(device)
                outputs = model(inputs.transpose(1, 2))
                predicted = outputs.argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        train_acc = 100. * train_correct / max(train_total, 1)
        test_acc = 100. * correct / max(total, 1)
        print('epoch %d/%d - train loss: %.4f, train accuracy: %.1f %%, test accuracy: %.1f %%'
              % (epoch + 1, epochs, loss_meter.avg, train_acc, test_acc))

In [None]:
# Full training run (15 epochs by default); prints per-epoch metrics.
train(model, train_loader, test_loader)

test accuracy: 14 %


KeyboardInterrupt: ignored