# Dynamic Graph CNN
Francesco Saverio Zuppichini

Implementation of [Dynamic Graph CNN for Learning on Point Clouds](https://arxiv.org/abs/1801.07829) for classification


## Data loading
Let's get the dataset

In [2]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2

In [11]:
import torch
from torch_geometric.datasets import ModelNet
import torch_geometric.transforms as T
import time
from tqdm import tqdm_notebook

device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=1)

In [39]:
pre_transform = T.NormalizeScale()
transform = T.Compose([T.SamplePoints(1024),
                       T.RandomRotate(30), 
                       T.RandomScale((0.5,2)), 
                       ])
name = '40'

train_ds = ModelNet(root='./',
             train=True,
             name=name,
             pre_transform=pre_transform,
             transform=transform)

test_ds = ModelNet(root='./',
             train=True,
             name=name,
             pre_transform=pre_transform,
             transform = T.SamplePoints(1024 * 2))

Now we have to define our dataloader, these guys will handle the thread queue to feed the GPU

In [40]:
from torch_geometric.data import DataLoader

train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)

test_dl = DataLoader(test_ds, batch_size=16)

## Model

Define our architecture

In [43]:
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
from torch.nn import Sequential, Linear, ReLU, BatchNorm2d
from torch_geometric.nn import EdgeConv, knn_graph, global_max_pool
from torch_geometric.nn import knn_graph
from torch.nn import Sequential as Seq, Linear as Lin, ReLU, BatchNorm1d
import torch.nn.functional as F


# BatchNorm1d does not work well with our images dataset, uncomment it to load the pretrain weights
# for the ModelNet40
class DynamicEdgeConv(gnn.EdgeConv):
    def __init__(self, k = 6, * args, ** kwargs):
        super().__init__( * args, ** kwargs)
        self.k = k

    def forward(self, pos, batch):
        edge_index = knn_graph(pos, self.k, batch, loop = False)
        return super().forward(pos, edge_index)

def fc_block(in_features, out_features):
    return Seq(
        Linear(in_features, out_features),
        BatchNorm1d(out_features),
        ReLU(inplace=True))

class DGCNNClassification(nn.Module):
    def __init__(self, in_channels, n_classes, k = 20):
        super(DGCNNClassification, self).__init__()

        self.convs = nn.ModuleList([
            DynamicEdgeConv(
                k = k,
                nn = Sequential(
                    fc_block(in_channels * 2, 64),
                    fc_block(64, 64),
                    fc_block(64, 64),
                ),
                aggr = 'max'),
            DynamicEdgeConv(
                k = k,
                nn = fc_block(64 * 2, 128),          
                aggr = 'max')
        ])

#         self.point_wise_features2higher_dim = fc_block(128, 512) # uncomment to try to bottleneck DGCNN 
        self.point_wise_features2higher_dim = fc_block(128, 1024)           

        self.tail = nn.Sequential(
            fc_block(1024, 512), # commet it to try to bottleneck DGCNN  
            fc_block(512, 256), 
            fc_block(256, n_classes)      
        )

        self.k = k

        self._initialize_weights()
    
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm1d): # very similar to resnet
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
    
    def forward(self, x, batch):
        out = x 
        for conv in self.convs:
            out = conv(out, batch)# it can be updated to use skipped connection 
        out = self.point_wise_features2higher_dim(out)
        out = global_max_pool(out, batch)
        out = self.tail(out)
        
        return out

## Training

In [5]:
def get_save_dir(name):
    now = time.time()
    save_dir = './model-{}-{}.pt'.format(name, str(now).replace('.', '-'))
    return save_dir

In [None]:
from torch.optim import Adam

model = DGCNNClassification(3,40).to(device)


optimizer = Adam(model.parameters(), 0.001)
criterion = nn.CrossEntropyLoss()

EPOCHS = 50

In [6]:
def run(epochs, dl, save_dir, train=True):
    bar = tqdm_notebook(range(epochs))
    last_acc = 0
    
    for epoch in bar:
        acc_tot = 0
        if (epoch + 1) % 10 == 0: 
            for g in optimizer.param_groups:
                g['lr'] = g['lr'] * 0.2
        bbar = tqdm_notebook(dl, leave=False)
        for i, data in enumerate(bbar):
            start = time.time()
            if train: optimizer.zero_grad()
            data = data.to(device)
            out = model(data.pos, data.batch)
            preds = torch.argmax(out, dim=-1)
            acc = (data.y == preds).float().sum() / preds.shape[0]
            acc_v = acc.cpu().item()
            acc_tot += acc_v
            loss = criterion(out, data.y)
            if train:
                loss.backward()
                optimizer.step()
            bbar.set_description('[INFO] loss={:.2f} acc={:.2f}'.format(loss, acc_v))
        mean_acc = acc_tot / i
        if train:
            if mean_acc > last_acc:
                last_acc = mean_acc
                torch.save(model.state_dict(), save_dir)

        bar.set_description('[INFO] acc={:.3f} best={:.3f}'.format(mean_acc, last_acc))

In [None]:
run(50, train_dl, get_save_dir('40'), train=True)

In [45]:
model = DGCNNClassification(3,40).to(device)
model.load_state_dict(torch.load('./model-40-1559918906-8202357.pt'))
model.eval()
run(1, test_dl, 'tmp', train=False)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=616), HTML(value='')))

## Traversability Estimation

### Loading Data

We have to convert each image to a graph, we can use `grid` to get the correct graph values and k-nn to reduce its dimensions. The following loads our custom dataset into memory.

In [7]:
import glob
import pandas as pd
import cv2
import numpy as np 

from torch.utils.data import Dataset, ConcatDataset
from torch_geometric.data import Data
from torch_geometric.transforms import KNNGraph
from torch_geometric.utils import grid

class TraversabilityDataset(Dataset):
    def __init__(self, df, hm,
                 patches_dir,
                 patch_size=None,
                 tr=None,
                 time_window=None,
                 transform=None,
                 more_than=None,
                 less_than=None,
                 down_sampling=None,
                 transform_with_label=None,
                 ):

        self.df = df
        self.hm = hm
        self.patches_dir = patches_dir
        self.patch_size = patch_size
        self.tr = tr
        self.time_window = time_window

        self.transform = transform
        self.transform_with_label = transform_with_label
        self.should_generate_paths = not 'images' in df

        if 'advancement' not in self.df:
            self.df = add_advancement(self.df, time_window)

        if down_sampling is not None:
            self.df = self.df[::down_sampling]

        if more_than is not None: self.df = self.df[self.df['advancement'] >= more_than]
        if less_than is not None: self.df = self.df[self.df['advancement'] <= less_than]
        if tr is not None and len(self.df) > 0:
            self.df["label"] = self.df["advancement"] > tr

        if tr is None:
            self.df["advancement"][self.df["advancement"] < 0] = 0


    def read_patch(self, img_name):
        patch = cv2.imread(self.patches_dir + '/' + img_name)
        patch = cv2.cvtColor(patch, cv2.COLOR_BGR2GRAY)
        patch = cv2.resize(patch, (patch.shape[-2] // 2, patch.shape[-1] // 2)) 
        patch = patch.astype(np.float32)
        patch /= 255
        return patch

    def generate_patch(self, row):
        patch = hmpatch(self.hm, row["hm_x"], row["hm_y"], np.rad2deg(row['pose__pose_e_orientation_z']),
                        self.patch_size,
                        scale=1)[0]

        return patch

    def __getitem__(self, idx):
        row = self.df.iloc[int(idx)]

        if self.should_generate_paths:
            patch = self.generate_patch(row)
        else:
            patch = self.read_patch(row['images'])

        y = row['advancement']

        if 'label' in self.df:
            y = row['label'].astype(np.long)

        if 'height' in row:
            patch *= row['height']

        y = torch.tensor(y)

        if 'label' in self.df and self.transform_with_label is not None:
            patch = self.transform_with_label(patch, row['label'])

        return self.transform(patch), y

    def __len__(self):
        return len(self.df)

    @classmethod
    def from_meta(cls, meta, base_dir, hm_dir, n=None, *args, **kwargs):
        datasets = []

        for (idx, row) in meta.iterrows():
            try:
                df, hm = open_df_and_hm_from_meta_row(row, base_dir, hm_dir)
            except FileNotFoundError:
                continue
            if len(df) > 0: datasets.append(cls(df, hm, *args, **kwargs))
        if n is not None: datasets = datasets[:n]
        concat_ds = ConcatDataset(datasets)
        concat_ds.c = 2
        concat_ds.classes = 'False', 'True'

        return concat_ds

    @staticmethod
    def concat_dfs(concat_ds):
        df = None
        for ds in concat_ds.datasets:
            if df is None:
                df = ds.df
            else:
                df = pd.concat([df, ds.df], sort=True)
        df = df.reset_index(drop=True)
        concat_ds.df = df
        return concat_ds

    @classmethod
    def from_root(cls, root, n=None, *args, **kwargs):
        dfs_paths = glob.glob(root + '/*.csv')
        if len(dfs_paths) == 0: dfs_paths = glob.glob(root + '/**/*.csv')
        datasets = []
        for df_path in dfs_paths:
            df = pd.read_csv(df_path)
            if len(df) > 0:
                datasets.append(cls(df, root, *args, **kwargs))
        if n is not None: datasets = datasets[:n]

        concat_ds = ConcatDataset(datasets)
        concat_ds.c = 2
        concat_ds.classes = 'False', 'True'

        return concat_ds

    @classmethod
    def from_dfs(cls, dfs, root, *args, **kwargs):
        datasets = []

        for df in dfs:
            if len(df) > 0:
                datasets.append(cls(df, root, *args, **kwargs))

        concat_ds = ConcatDataset(datasets)
        concat_ds.c = 2
        concat_ds.classes = 'False', 'True'
        return concat_ds

class CenterAndScalePatch():
    """
    This class is used to center in the middle and rescale a given
    patch. We need to center in the  middle in order to
    decouple the root position from the classification task. Also,
    depending on the map, we need to multiply the patch by a scaling factor.
    """

    def __init__(self, scale=1.0, debug=False, ):
        self.scale = scale
        self.debug = debug

    def show_heatmap(self, x, title, ax):
        ax.set_title(title)
        img_n = x
        sns.heatmap(img_n,
                    ax=ax,
                    fmt='0.2f')

    def __call__(self, x, debug=False):
        if self.debug: fig = plt.figure()

        if self.debug:
            ax = plt.subplot(2, 2, 1)
            self.show_heatmap(x, 'original', ax)

        x *= self.scale
        if self.debug:
            ax = plt.subplot(2, 2, 2)
            self.show_heatmap(x, 'scale', ax)
        center = x[x.shape[0] // 2, x.shape[1] // 2]
        x -= center

        if self.debug:
            ax = plt.subplot(2, 2, 3)
            self.show_heatmap(x, 'centered {}'.format(center), ax)

        if self.debug:
            ax = plt.subplot(2, 2, 4)
            self.show_heatmap(x, 'final', ax)

        if self.debug: plt.show()
        return x

    

### Graphs from patches
To create the graphs we can subclass TraversabilityDataset and on the fly use the `grid` function to make them graphs

In [8]:
class TraversabilityGraphDataset(TraversabilityDataset):
    def __init__(self,*args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def __getitem__(self, idx):
        img, y = super().__getitem__(idx)
        img = img.squeeze()
        edge_index, pos = grid(img.shape[-2], img.shape[-1])
        pos = torch.cat([pos.float(), img.reshape(img.shape[-2] * img.shape[-1], 1)], dim=-1)
        # we set it pos since in the .run function we where using the 3d point pos as feature
        data = Data(pos=pos.float(), y=y.item())
        return data


In [36]:
from torchvision.transforms import ToTensor, Compose
from torch.nn import Dropout

TRAIN_ROOT = '/home/fzuppic/data/test/'
TEST_ROOT = '/home/fzuppic/data/test/100/'

train_tr = Compose([CenterAndScalePatch(), ToTensor(), Dropout(0.1)])
test_tr = Compose([CenterAndScalePatch(), ToTensor()])

train_ds = TraversabilityGraphDataset.from_root(TRAIN_ROOT, 
                                                patches_dir=TRAIN_ROOT + '/patches/', 
                                                down_sampling=2,
                                                transform=train_tr, 
                                                tr=0.2)

test_ds = TraversabilityGraphDataset.from_root(TEST_ROOT, 
                                                patches_dir=TEST_ROOT + '/patches/', 
                                                down_sampling=2,
                                                transform=train_tr, 
                                                tr=0.2)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy


In [37]:
from torch.optim import Adam
from torch_geometric.data import DataLoader

model = DGCNNClassification(3,2, k=6).to(device)

optimizer = Adam(model.parameters(), 0.0001)
criterion = nn.CrossEntropyLoss()

EPOCHS = 50
train_dl = DataLoader(train_ds, batch_size=16, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=16, shuffle=True)


In [None]:
run(50, train_dl, get_save_dir('traversability'), train=True)

In [38]:
model = DGCNNClassification(3,2).to(device)
model.load_state_dict(torch.load('./model-traversability-1560089226-7143812.pt'))
model.eval()
run(1, test_dl, None, train=False)

HBox(children=(IntProgress(value=0, max=1), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1130), HTML(value='')))