# Dynamic Graph CNN

## Data loading
Load the ModelNet point-cloud dataset.

In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [None]:
import torch
from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T
import time
from tqdm import tqdm_notebook

# Applied once, when the raw meshes are first processed into the dataset.
pre_transform = T.NormalizeScale()
# Per-sample training augmentation: sample 1024 surface points, then apply a
# random rotation (RandomRotate(30)) and a random scaling in (0.5, 2).
transform = T.Compose([
    T.SamplePoints(1024),
    T.RandomRotate(30),
    T.RandomScale((0.5, 2)),
])
name = '40'

train_ds = ModelNet(root='./',
                    train=True,
                    name=name,
                    pre_transform=pre_transform,
                    transform=transform)

# The held-out split must use train=False; with train=True the "test" loader
# would iterate over the training shapes and report training accuracy.
test_ds = ModelNet(root='./',
                   train=False,
                   name=name,
                   pre_transform=pre_transform,
                   # NOTE(review): 4x the training point count, so kNN
                   # neighbourhoods are denser than during training --
                   # confirm this is intended.
                   transform=T.SamplePoints(1024 * 4))

In [None]:
# Prefer the second GPU (the original choice). Fall back to the first GPU on
# single-GPU machines, and to the CPU when CUDA is unavailable.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

Now we define the data loaders, which group the point clouds into mini-batches and feed them to the model.

In [17]:
from torch_geometric.data import DataLoader

# Training batches are reshuffled each epoch; evaluation keeps a fixed order.
train_dl = DataLoader(dataset=train_ds, batch_size=32, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=8)

## Model

Define the architecture: dynamic EdgeConv layers, a point-wise projection, global max pooling and an MLP classifier.

In [5]:
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import EdgeConv, knn_graph, global_max_pool
from torch_geometric.nn import knn_graph
from torch.nn import Sequential as Seq, Linear as Lin, ReLU
import torch.nn.functional as F

class DynamicEdgeConv(gnn.EdgeConv):
    """EdgeConv whose kNN graph is rebuilt from its input on every call.

    Args:
        k: number of nearest neighbours used to build the graph.
        *args, **kwargs: forwarded unchanged to ``gnn.EdgeConv``
            (e.g. ``nn=`` and ``aggr=``).
    """

    def __init__(self, k=6, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.k = k

    def forward(self, pos, batch):
        # Neighbours are recomputed from the current features (no self-loops),
        # so deeper layers see a graph in feature space, not input space.
        neighbours = knn_graph(x=pos, k=self.k, batch=batch, loop=False)
        return super().forward(pos, neighbours)

class DGCNNClassification(nn.Module):
    """DGCNN-style point-cloud classifier.

    Two dynamic EdgeConv layers (max aggregation) lift per-point features to
    128 channels. A point-wise linear layer lifts them to 512, global max
    pooling yields one vector per example, and a two-layer MLP maps that
    vector to ``n_classes`` logits.

    Args:
        in_channels: number of input features per point.
        n_classes: number of output logits.
        k: neighbours per point in each dynamically built kNN graph.
    """

    def __init__(self, in_channels, n_classes, k=20):
        super().__init__()

        # Edge MLPs see [x_i, x_j - x_i], hence the doubled input width.
        first_edge_mlp = Sequential(
            Linear(in_channels * 2, 64), ReLU(),
            Linear(64, 64), ReLU(),
            Linear(64, 64), ReLU(),
        )
        first_conv = DynamicEdgeConv(k=k, nn=first_edge_mlp, aggr='max')

        second_edge_mlp = Sequential(
            Linear(64 * 2, 128), ReLU(),
        )
        second_conv = DynamicEdgeConv(k=k, nn=second_edge_mlp, aggr='max')

        self.convs = nn.ModuleList([first_conv, second_conv])

        self.point_wise_features2higher_dim = nn.Sequential(
            nn.Linear(128, 512),
            nn.ReLU(),
        )

        self.tail = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_classes),
        )

        self.k = k

    def forward(self, x, batch):
        # x: per-point features whose last dimension is ``in_channels``;
        # batch: per-point example index, as consumed by knn_graph and
        # global_max_pool (assumed PyG batch vector -- TODO confirm).
        features = x
        for edge_conv in self.convs:
            features = edge_conv(features, batch)
        lifted = self.point_wise_features2higher_dim(features)
        pooled = global_max_pool(lifted, batch)
        return self.tail(pooled)

## Training

In [6]:
# Checkpoint file path (despite the name, a file, not a directory), made
# unique per session by the current UNIX timestamp.
save_dir = f'./model-{name}-{time.time()}'
save_dir

'./model-40-1558955287.9842556'

In [7]:
from torch.optim import Adam

# 3 input channels: run() feeds the xyz coordinates (data.pos).
# The head is sized from the dataset: ModelNet40 has 40 labels, and a
# hard-coded 10-way head makes CrossEntropyLoss fail on any label >= 10.
model = DGCNNClassification(3, train_ds.num_classes).to(device)

optimizer = Adam(model.parameters(), 0.001)
criterion = nn.CrossEntropyLoss()

# NOTE(review): not referenced by the visible training call, which passes 40.
EPOCHS = 50

In [21]:
def run(epochs, dl, train=True):
    """Run ``epochs`` passes of the global ``model`` over ``dl``.

    Relies on the notebook globals ``model``, ``optimizer``, ``criterion``,
    ``device`` and ``save_dir``.

    When ``train`` is True:
    - the model is updated with ``optimizer``;
    - the learning rate is multiplied by 0.2 at the start of every 10th
      epoch;
    - the weights are saved to ``save_dir`` whenever the epoch accuracy
      improves.

    When ``train`` is False, gradients are disabled and the optimizer
    (including its learning rate) is left untouched.

    Args:
        epochs: number of passes over ``dl``.
        dl: iterable of batches exposing ``pos``, ``batch``, ``y`` and ``to``.
        train: optimise (True) or only evaluate (False).
    """
    bar = tqdm_notebook(range(epochs))
    best_acc = 0
    model.train(train)

    for epoch in bar:
        # Step learning-rate decay; must not touch the optimizer when evaluating.
        if train and (epoch + 1) % 10 == 0:
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] * 0.2

        n_correct, n_seen = 0, 0
        bbar = tqdm_notebook(dl, leave=False)
        # Skip building autograd graphs when evaluating (saves memory/time).
        with torch.set_grad_enabled(train):
            for data in bbar:
                if train:
                    optimizer.zero_grad()
                data = data.to(device)
                out = model(data.pos, data.batch)
                preds = torch.argmax(out, dim=-1)
                batch_correct = (preds == data.y).sum().item()
                n_correct += batch_correct
                n_seen += preds.shape[0]
                loss = criterion(out, data.y)
                if train:
                    loss.backward()
                    optimizer.step()
                bbar.set_description('[INFO] loss={:.2f} acc={:.2f}'.format(
                    loss.item(), batch_correct / preds.shape[0]))

        # Accuracy over every sample of the epoch. Weighting by samples keeps a
        # smaller final batch from skewing it. An empty loader gives 0.0
        # instead of crashing.
        mean_acc = n_correct / n_seen if n_seen else 0.0
        if train and mean_acc > best_acc:
            best_acc = mean_acc
            torch.save(model.state_dict(), save_dir)

        bar.set_description('[INFO] acc={:.3f} best={:.3f}'.format(mean_acc, best_acc))

In [None]:
# Train for 40 epochs; run() checkpoints the best epoch to save_dir.
run(epochs=40, dl=train_dl, train=True)

In [22]:
# Evaluate the checkpoint written by this session's training run. To evaluate
# an earlier run instead, point this at its './model-<name>-<timestamp>' file.
checkpoint_path = save_dir
# Head sized from the dataset so it matches the labels (and the trained model).
model = DGCNNClassification(3, train_ds.num_classes).to(device)
# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
model.eval()
run(1, test_dl, train=False)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=499), HTML(value='')))