In [1]:
import os
import sys
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from models.mesh_classifier import ClassifierModel

from data.data import ClassificationData, collate_fn
import utils.util as util

In [2]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [3]:
# Dataset root, passed to ClassificationData for both the 'train' and 'test' phases.
dataroot = 'datasets/human_class'
# Passed to ClassificationData and MeshClassifier as `ninput_edges`
# (presumably the maximum edge count per mesh -- TODO confirm).
ninput_edges = 40000
# Passed to ClassificationData as `num_aug` (presumably the number of augmented copies -- confirm).
num_aug = 1

# Seed every RNG the notebook relies on (DataLoader shuffling, weight initialisation)
# so that Restart & Run All reproduces the logged losses and accuracies.
# torch.manual_seed also seeds all CUDA devices.
seed = 0
torch.manual_seed(seed)
np.random.seed(seed)

In [4]:
def make_dataloader(phase, shuffle, batch_size=1):
    """Build the ClassificationData split `phase` and wrap it in a DataLoader.

    Uses the notebook-level `dataroot`, `device`, `ninput_edges` and `num_aug`
    settings, and the project's `collate_fn` for batching.

    Args:
        phase: split name forwarded to ClassificationData ('train' or 'test').
        shuffle: whether the DataLoader reshuffles the split every epoch.
        batch_size: meshes per batch (the notebook trains with 1).

    Returns:
        (dataset, dataloader) tuple.
    """
    dataset = ClassificationData(dataroot=dataroot, phase=phase, device=device,
                                 ninput_edges=ninput_edges, num_aug=num_aug)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                        collate_fn=collate_fn)
    return dataset, loader


train_dataset, train_dataloader = make_dataloader('train', shuffle=True)
test_dataset, test_dataloader = make_dataloader('test', shuffle=False)

nclasses = train_dataset.get_nclasses()
ninput_channels = train_dataset.get_ninput_channels()

# Both splits must share the same label space and feature layout; otherwise
# the test accuracy computed later would be silently meaningless.
assert test_dataset.get_nclasses() == nclasses, (
    f'train/test class count mismatch: {nclasses} vs {test_dataset.get_nclasses()}')
assert test_dataset.get_ninput_channels() == ninput_channels, (
    f'train/test input channel mismatch: {ninput_channels} vs {test_dataset.get_ninput_channels()}')

print('nclasses', nclasses)
print('ninput_channels', ninput_channels)

loaded mean / std from cache
loaded mean / std from cache
nclasses 2
ninput_channels 5


In [5]:
from models import networks

# Weight-initialisation scheme and gain handed to networks.init_weights.
init_type = 'normal'
init_gain = 0.02

# Adam hyperparameters.
learning_rate = 0.005
adam_betas = (0.9, 0.999)

# Build the classifier, initialise its weights, then move it to the target device
# (nn.Module.to moves parameters in place and returns the module itself).
net = networks.init_weights(
    networks.MeshClassifier(ninput_channels, nclasses, ninput_edges),
    init_type,
    init_gain,
)
net = net.to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate, betas=adam_betas)

In [6]:
num_epochs = 10
print_freq = 1  # print the training loss every `print_freq` iterations; raise it to shorten the log


def unpack_batch(data):
    """Split one collated batch into (mesh, labels, edge_features).

    Labels become int64 and edge features float32; both are moved to `device`.
    The mesh entry is passed through unchanged.
    """
    mesh = data['mesh']
    labels = torch.from_numpy(data['label']).long().to(device)
    edge_features = torch.from_numpy(data['edge_features']).float().to(device)
    return mesh, labels, edge_features


def match_labels(labels, out):
    """Return `labels` flattened and aligned row-for-row with the logits `out`.

    If there is already one label per output row, the labels are returned as is.
    A single label (the batch_size=1 case) is repeated once per output row, which
    is exactly what the earlier `labels.repeat(out.shape[0])` did. That old
    expression produced B*B labels for B rows whenever batch_size > 1, so any
    other combination now raises instead.
    """
    labels = labels.reshape(-1)
    n_rows = out.shape[0]
    if labels.numel() == n_rows:
        return labels
    if labels.numel() == 1:
        return labels.repeat(n_rows)
    raise ValueError(f'cannot align {labels.numel()} labels with {n_rows} output rows')


for epoch in range(num_epochs):
    # ---- training ----
    net.train()
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(train_dataloader):
        mesh, labels, edge_features = unpack_batch(data)
        # Kept from the original notebook: input features are marked as requiring grad
        # while training. NOTE(review): confirm the network actually needs this.
        edge_features.requires_grad_(True)

        optimizer.zero_grad()
        out = net(edge_features, mesh)
        loss = criterion(out, match_labels(labels, out))
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        running_loss += loss_value
        n_batches += 1
        if (i + 1) % print_freq == 0:
            print(f'[Epoch {epoch+1} - {i+1}/{len(train_dataloader)}] loss: {loss_value}')

    if n_batches:
        print(f'[Epoch {epoch+1}] mean train loss: {running_loss / n_batches}')

    # ---- evaluation ----
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_dataloader:
            mesh, labels, edge_features = unpack_batch(data)
            out = net(edge_features, mesh)
            labels = match_labels(labels, out)
            correct += (out.argmax(dim=1) == labels).sum().item()
            total += labels.numel()

    # Weight every prediction equally. Averaging per-batch means would over-weight a
    # short final batch. An empty test set gives NaN instead of a ZeroDivisionError.
    # `accuracy` stays a tensor so later cells calling `accuracy.item()` still work.
    accuracy = torch.tensor(correct / total if total else float('nan'))
    print(f'[Epoch {epoch+1}] accuracy: {accuracy.item()}')

[Epoch 1 - 1/42] loss: 0.7163923382759094
[Epoch 1 - 2/42] loss: 0.31367039680480957
[Epoch 1 - 3/42] loss: 0.1931520700454712
[Epoch 1 - 4/42] loss: 4.96767520904541
[Epoch 1 - 5/42] loss: 3.309171199798584
[Epoch 1 - 6/42] loss: 0.2033335268497467
[Epoch 1 - 7/42] loss: 0.2754368185997009
[Epoch 1 - 8/42] loss: 0.9970077276229858
[Epoch 1 - 9/42] loss: 0.8054834008216858
[Epoch 1 - 10/42] loss: 0.6607300043106079
[Epoch 1 - 11/42] loss: 0.7123059034347534
[Epoch 1 - 12/42] loss: 0.6767155528068542
[Epoch 1 - 13/42] loss: 0.7075507044792175
[Epoch 1 - 14/42] loss: 0.6795034408569336
[Epoch 1 - 15/42] loss: 0.6801548004150391
[Epoch 1 - 16/42] loss: 0.705324649810791
[Epoch 1 - 17/42] loss: 0.6804910898208618
[Epoch 1 - 18/42] loss: 0.6802354454994202
[Epoch 1 - 19/42] loss: 0.6793195605278015
[Epoch 1 - 20/42] loss: 0.6777113080024719
[Epoch 1 - 21/42] loss: 0.7102762460708618
[Epoch 1 - 22/42] loss: 0.6746072769165039
[Epoch 1 - 23/42] loss: 0.7109529376029968
[Epoch 1 - 24/42] loss: