# Code

## Code attributions

Pyg: 
- https://pytorch-geometric.readthedocs.io/en/latest/notes/colabs.html

Wandb:
- https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Simple_PyTorch_Integration.ipynb
- https://colab.research.google.com/github/wandb/examples/blob/master/colabs/pytorch/Organizing_Hyperparameter_Sweeps_in_PyTorch_with_W%26B.ipynb

## Dependencies and setup

In [1]:
#@title Dependencies

import os
import itertools
import torch
import copy
from tqdm.notebook import tqdm
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

%matplotlib inline
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.manual_seed(12345)

#wandb setup
!pip install wandb
import wandb
wandb.login()

1.12.0+cu116
[0mCollecting wandb
  Downloading wandb-0.13.3-py2.py3-none-any.whl (1.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m80.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting setproctitle
  Downloading setproctitle-1.3.2-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.9.8-py2.py3-none-any.whl (158 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m158.7/158.7 kB[0m [31m42.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting promise<3,>=2.0
  Downloading promise-2.3.tar.gz (19 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

  ········································


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [3]:
#@title Utility functions
from collections import defaultdict

#Visualization methods for 3d shapes and point clouds
def visualize_mesh(pos, face):
    fig = plt.figure()
    ax = fig.gca(projection='3d')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    ax.plot_trisurf(pos[:, 0], pos[:, 1], pos[:, 2], triangles=face.t(), antialiased=False)
    plt.show()


def visualize_points(pos, edge_index=None, index=None):
    fig = plt.figure(figsize=(4, 4))
    if edge_index is not None:
        for (src, dst) in edge_index.t().tolist():
             src = pos[src].tolist()
             dst = pos[dst].tolist()
             plt.plot([src[0], dst[0]], [src[1], dst[1]], linewidth=1, color='black')
    if index is None:
        plt.scatter(pos[:, 0], pos[:, 1], s=50, zorder=1000)
    else:
       mask = torch.zeros(pos.size(0), dtype=torch.bool)
       mask[index] = True
       plt.scatter(pos[~mask, 0], pos[~mask, 1], s=50, color='lightgray', zorder=1000)
       plt.scatter(pos[mask, 0], pos[mask, 1], s=50, zorder=1000)
    plt.axis('off')
    plt.show()
    
#other utility methods for later
def get_optimizer(opt_name, model, lr):
  #There is no match in python 3.7, which is the version colab uses :/
    if opt_name == "Adam":
        return torch.optim.Adam(model.parameters(), lr)
    elif opt_name == "SGD":
        return torch.optim.SGD(model.parameters(), lr, momentum=1e-4,
                           dampening=1e-6)
    return None
    
def shrink_ModelNet(dataset, max):
    datalist = []
    num_classes = defaultdict(lambda: 0)
    for data in dataset:
        class_num = int(data[0].y)
        if num_classes[class_num] < max:
            datalist.append(data)
            num_classes[class_num] += 1
    
    return datalist

In [4]:
#@title Additional configs
from torch_geometric.transforms import RandomRotate, Compose, SamplePoints, ToDevice

dataset = "ModelNet"

if dataset=="ModelNet":
    nr_classes=10
    batch_size=100
    max_item_per_class=101
    nr_points=1024
    training_ds_root = "training_data_1"
    test_ds_root = "test_data_1"
else:
    nr_classes=40
    batch_size=40
    nr_points=256
    max_item_per_class=1
    training_ds_root = "training_data",
    test_ds_root = "test_data"
    
#Augmentors and transformers for the data
augmentor = Compose([
    RandomRotate(degrees=90, axis=0),
    RandomRotate(degrees=90, axis=1),
    RandomRotate(degrees=90, axis=2)
])

transformer=Compose([
    SamplePoints(num=nr_points, include_normals=True),
    ToDevice(device)
])

#Create dictionary with hp 
config_contrastive=dict(
    dataset = dataset,
    max_item_per_class=max_item_per_class,
    epochs=130,
    classes=nr_classes,
    batch_size=batch_size,
    lr=0.01,
    temperature=0.03,
    optimizer="Adam",
    train_ds_root = training_ds_root,
    test_ds_root = test_ds_root,
    classifier_epochs=200,
    classifier_lr = 0.05
)

## Dataset and model

In [5]:
#@title Extend dataset with augmented samples

from torch_geometric.data import Dataset
from torch_geometric.datasets import GeometricShapes, ModelNet
from torch_geometric.loader import DataLoader
  
#Extend existing Class
class AugmentedDS(Dataset):

  """An augmented version of the GeometricShapes dataset"""

  def __init__(self, root: str, train:bool, augmentor, transformer, ds_name: str):
        
        if ds_name == "GeometricShapes":
            self.dataset = GeometricShapes(root = root, train = train)
        else:
            self.dataset = ModelNet(root=root, train=train, name='10')
        self.transformer = transformer
        self.augmentor_transformer = Compose([augmentor,transformer])

  def __len__(self):
    return len(self.dataset)

  def __getitem__(self, idx):

    original_shape = self.dataset[idx]
    augmented_shape1 = self.augmentor_transformer(original_shape.clone()) 
    augmented_shape2 = self.augmentor_transformer(original_shape.clone())
    augmented_shape3 = self.augmentor_transformer(original_shape.clone()) 
    original_shape = self.transformer(original_shape)

    return original_shape, augmented_shape1, augmented_shape2, augmented_shape3


In [6]:
#@title Create simple PPFnet

from torch_geometric.nn import PPFConv, global_max_pool
from torch.nn import Sequential, Linear, ReLU
from torch_cluster import fps, knn_graph

class PPFNet(torch.nn.Module):
    def __init__(self, feat_dim, hidden_layer_dim):
        super().__init__()

        #Dimension of hidden layers
        self.hidden_layer_dim = hidden_layer_dim
        #Dimension of feature, 3 with 3D fd and 4 with 4D
        self.feat_dim = feat_dim

        
        mlp1 = Sequential(Linear(self.feat_dim, self.hidden_layer_dim),
                              ReLU(),
                              Linear(self.hidden_layer_dim, self.hidden_layer_dim*2))
        self.conv1 = PPFConv(local_nn = mlp1)  

        
        mlp2 = Sequential(Linear(self.hidden_layer_dim*2 + self.feat_dim, self.hidden_layer_dim*2 + self.feat_dim),
                              ReLU(),
                              Linear(self.hidden_layer_dim*2 + self.feat_dim, self.hidden_layer_dim*2 + self.feat_dim))  
        self.conv2 = PPFConv(local_nn = mlp2)  
        
        self.prj_head = Sequential(
            Linear(self.hidden_layer_dim*2 + self.feat_dim, 100),
            ReLU(),
            Linear(100,100)
        )
        
    def forward(self, pos, normal, batch):

        #Create edges in the point cloud
        edge_index = knn_graph(pos, k=16, batch=batch, loop=False)
        
        #There are no features in first convolution
        #Other datasets different from GS or MN may have them
        x = self.conv1(x=None, pos=pos, normal=normal, edge_index=edge_index)
        x = x.relu()
        
        if True:
            #farthest point sampling
            index = fps(pos, batch, ratio=0.5)
            x = x[index]
            pos = pos[index]
            normal = normal[index]
            batch = batch[index]
            edge_index = knn_graph(pos, k=16, batch=batch, loop=False)
        
        x = self.conv2(x=x, pos=pos, normal=normal, edge_index=edge_index)
        x = x.relu()
        
        x = global_max_pool(x, batch)  # [num_examples, hidden_channels]
        return self.prj_head(x)

In [7]:
def make_contrastive(config):
    # Make the data
    train_ds = AugmentedDS(root = config.train_ds_root, augmentor=augmentor,
                       transformer=transformer, train=True, ds_name = config.dataset)
    test_ds = AugmentedDS(root = config.test_ds_root, augmentor=augmentor,
                       transformer=transformer, train=False, ds_name = config.dataset)
    
    #Shrink dataset if it's modelnet
    if dataset == "ModelNet":
        datalist_train = shrink_ModelNet(train_ds, config.max_item_per_class)
        datalist_test = shrink_ModelNet(test_ds, config.max_item_per_class)

        train_dl = DataLoader(datalist_train, batch_size = config.batch_size, shuffle=True)
        test_dl = DataLoader(datalist_test, batch_size = config.batch_size, shuffle=True)
    else:
        train_dl = DataLoader(train_ds, batch_size = config.batch_size, shuffle=True)
        test_dl = DataLoader(test_ds, batch_size = config.batch_size, shuffle=True) 
    
    # Make the model
    model = PPFNet(feat_dim=4, hidden_layer_dim = 32).to(device)

    # Make optimizer
    optimizer = get_optimizer(opt_name = config.optimizer, model = model, lr=config.lr)
    
    return model, train_dl, test_dl, optimizer

## Training model

In [8]:
#@title InfoNCE loss implementations
from torch.nn import CosineSimilarity

cos = CosineSimilarity(dim=0)
cos2 = CosineSimilarity(dim=1)

def InfoNCELossSN(anchors, augmented, temperature=0.05):

  loss = 0
  batch_len = anchors.shape[0]

  for index in range(batch_len):
    anchor = anchors[index]
    pos_sample = augmented[index]
    pos_sim = torch.exp(cos(anchor, pos_sample) / temperature)
    neg_sim = 0
    
    if index == 0:
        negatives = augmented[index+1:, :]
    elif index == batch_len-1:
        negatives = augmented[:index, :]
    else:
        negatives = torch.cat((augmented[:index, :], augmented[index+1:, :]), dim=0)
    
    neg_sim = torch.einsum("i-> ", torch.exp(cos2(anchor, negatives) / temperature))
        
    loss += -torch.log(pos_sim / (pos_sim + neg_sim))
    
  return loss 

def InfoNCELoss(anchors, augmented1, augmented2, temperature=0.05):

  """InfoNCE with support to multiple positives"""
  loss = 0
  batch_len = anchors.shape[0]

  for index in range(batch_len):
    anchor = anchors[index]
    positives = (torch.cat((augmented1[index], augmented2[index]), dim=0)).reshape(2, -1)
    pos_sim = torch.einsum("i->",torch.exp(cos2(anchor, positives) / temperature))
    
    neg_sim = 0
    #Handle first and last row
    if index == 0:
        negatives = torch.cat((augmented1[index+1:, :], augmented2[index+1:, :]), dim=0)
    elif index == batch_len-1:
        negatives = torch.cat((augmented1[:index, :],augmented2[:index, :]),dim=0)
    else:
        n1 = torch.cat((augmented1[:index, :], augmented1[index+1:, :]), dim=0)
        n2 = torch.cat((augmented2[:index, :], augmented2[index+1:, :]), dim=0)
        negatives = torch.cat((n1, n2), dim=0)
         
    neg_sim = torch.einsum("i-> ", torch.exp(cos2(anchor, negatives) / temperature))
    
    loss += -torch.log(pos_sim / (pos_sim + neg_sim))

  return loss


In [9]:
#@title Training functions

def train_batch(model, optim, loader, temperature, scheduler=None):
    model.train()

    total_loss = 0
    for data in loader:
        optim.zero_grad()  # Clear gradients.
        
        #encode original and augmentations
        original, augmented1, augmented2 = data[0], data[1], data[2]
        z1 = model(original.pos, original.normal, original.batch)
        z2 = model(augmented1.pos, augmented1.normal, augmented1.batch)
        z3 = model(augmented2.pos, augmented2.normal, augmented2.batch)

        loss = InfoNCELoss(anchors=z1, augmented1=z2, augmented2=z3, 
                            temperature=temperature)
        #loss = InfoNCELossSN(anchors=z1, augmented=z2, temperature=temperature)
                            
        loss.backward()
        optim.step()
        total_loss += loss

    scheduler.step() if scheduler is not None else None
    return total_loss                 

def train_log(loss, epoch):
    # log to wandb
    wandb.log({"loss": loss}, step=epoch)
    
    
def train(model, loader, optimizer, config):
    # Tell wandb to watch what the model gets up to: gradients, weights, and more!
    wandb.watch(model, log="all", log_freq=10)
    
    for epoch in tqdm(range(config.epochs)):
        loss = train_batch(model, optimizer, loader, config.temperature)
        train_log(loss, epoch)   

In [13]:
#@title Test the result of contrastive training
#Note that the network has to be trained in order to be used for testing

class SimpleClassifier(torch.nn.Module):

  def __init__(self):
    super().__init__()

    self.mlp = torch.nn.Linear(in_features=100, out_features=10)
    self.activation = torch.nn.ReLU()

  def forward(self, x):
    return self.activation(self.mlp(x))


def train_classifier(model, classifier, optim, loader, scheduler=None):
    classifier.train()
    
    total_loss = 0
    loss_f = torch.nn.CrossEntropyLoss()
    for data in loader:

        optim.zero_grad()  # Clear gradients.
        original = data[0]
        z1 = model(original.pos, original.normal, original.batch)
        
        logits = classifier(z1)
        loss = loss_f(logits, data[0].y)
        loss.backward()  # Backward pass.
        
        optim.step()  # Update model parameters.
        total_loss += loss

    scheduler.step() if scheduler is not None else None
    return total_loss / len(loader.dataset)


@torch.no_grad()
def test_classifier(model, classifier, loader):
    classifier.eval()

    total_correct = 0
    for data in loader:

        original = data[0]
        z1 = model(original.pos, original.normal, original.batch)
        logits = classifier(z1)
        preds = torch.argmax(logits, dim=-1)
        total_correct += int((preds == data[0].y).sum())

    return total_correct/(len(loader.dataset))

def test(model, train_loader, test_loader, config):
    # Run the model on some test examples
    
    classifier = SimpleClassifier().to(device)
    optimizer = get_optimizer(opt_name = config.optimizer, model = classifier, lr=0.01)
    wandb.watch(classifier, log="all", log_freq=10)
    
    for epoch in tqdm(range(config.classifier_epochs)):
        loss = train_classifier(model, classifier, optimizer, train_loader)
        wandb.log({"classifier train loss": loss}, step=epoch)   
        
        test_acc = test_classifier(model, classifier, test_loader)
        wandb.log({"classifier test accuracy": test_acc}, step=epoch)

In [14]:
def model_pipeline(hyperparameters):
    with wandb.init(project="contrastive_training", config=hyperparameters):
        config = wandb.config
        # make the model, data, and optimization problem
        model, train_loader, test_loader, optimizer = make_contrastive(config)
        print(model)

        #and use them to train the model
        train(model, train_loader, optimizer, config)

        # and test its final performance
        test(model, train_loader, test_loader, config)

    return model

In [15]:
#@title Run pipeline

model = model_pipeline(config_contrastive)

[34m[1mwandb[0m: Currently logged in as: [33mmattewg_dev[0m. Use [1m`wandb login --relogin`[0m to force relogin


PPFNet(
  (conv1): PPFConv(local_nn=Sequential(
    (0): Linear(in_features=4, out_features=32, bias=True)
    (1): ReLU()
    (2): Linear(in_features=32, out_features=64, bias=True)
  ), global_nn=None)
  (conv2): PPFConv(local_nn=Sequential(
    (0): Linear(in_features=68, out_features=68, bias=True)
    (1): ReLU()
    (2): Linear(in_features=68, out_features=68, bias=True)
  ), global_nn=None)
  (prj_head): Sequential(
    (0): Linear(in_features=68, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=100, bias=True)
  )
)


  0%|          | 0/130 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
classifier test accuracy,▃█▆▃▅▃▄▂▂▃▄▁▃█▄▄▁▅▆▅▇▂▅▅▄▃▄▆▃▆▃█▆▁▂▃▇▄▆▂
classifier train loss,▄▄▆▄▅▅▆▅▄▄▃▇▆▄▃▃▄▃▁█▆▄▃▇▅▄▅▅▅▅▂▃▃▅▁▄▁▆▄▇
loss,█▅▃▃▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
classifier test accuracy,0.57269
classifier train loss,0.01287
loss,63.65488
