In [78]:
%load_ext autoreload
%autoreload 2

In [1]:
import os

import numpy as np
from pathlib import Path

import torch

from torch_geometric.nn import knn_graph
from torch_geometric import transforms
from torch_geometric.data import Data, DataLoader

from geometric_vector_perceptron import GVP_Network

In [20]:
# Record the PyTorch version these outputs were produced with.
print(torch.__version__)

1.8.1


In [22]:
# Show the working directory; the data path in the next cell is built relative to it.
!pwd

/home/ysk2a15/mydocuments/EEE_Y4/COMP6248/Geometric-Vector-Perceptron


In [2]:
# Synthetic dataset directory, resolved relative to the current working directory.
data_path = Path.cwd() / "gvp" / "data" / "synthetic"

cnn = torch.from_numpy(np.load(data_path / "cnn.npy"))
synthetic = torch.from_numpy(np.load(data_path / "synthetic.npy"))
with np.load(data_path / "answers.npz") as answers:
    perimeter = torch.from_numpy(answers["perimeter"])
# off_center is read from OCR.npy rather than answers.npz["off_center"].
# NOTE(review): the reason for this switch is not recorded here - TODO document it.
off_center = torch.from_numpy(np.load(data_path / "OCR.npy"))

In [27]:
# Inspect the shape of the loaded synthetic tensor.
synthetic.shape

torch.Size([20000, 2, 100, 3])

In [34]:
# Largest off-center target. Tensor.max() reduces in a single vectorised call;
# the builtin max() would iterate the tensor element by element in Python.
off_center.max()

tensor(8.5405)

In [3]:
def diff_matrix(vectors):
    """All pairwise differences between the points of each batch element.

    Args:
        vectors: tensor of shape ``(b, n, d)``.

    Returns:
        Tensor of shape ``(b, n * n, d)`` whose row ``i * n + j`` holds
        ``vectors[:, i] - vectors[:, j]``.
    """
    batch, _n_points, dim = vectors.shape
    heads = vectors.unsqueeze(-2)  # (b, n, 1, d): the "i" operand
    tails = vectors.unsqueeze(1)   # (b, 1, n, d): the "j" operand
    pairwise = heads - tails       # broadcasts to (b, n, n, d)
    return pairwise.reshape(batch, -1, dim)

In [4]:
# Walk through what diff_matrix does, step by step, on a random (1, 32, 3) cloud.
vectors = torch.randn(1, 32, 3)
print(vectors.shape)

rows = vectors[..., None, :]   # (1, 32, 1, 3)
cols = vectors[:, None, ...]   # (1, 1, 32, 3)
print(rows.shape)
print(cols.shape)

# Broadcasting produces every pairwise difference rows[i] - cols[j].
pairwise = rows - cols         # (1, 32, 32, 3)
print(pairwise.shape)

# Flatten the two point axes into one, exactly as diff_matrix does.
pairwise_flat = pairwise.reshape(pairwise.shape[0], -1, pairwise.shape[-1])
print(pairwise_flat.shape)     # (1, 1024, 3)

# The manual steps must reproduce the helper.
assert torch.equal(pairwise_flat, diff_matrix(vectors))

torch.Size([1, 32, 3])
torch.Size([1, 32, 1, 3])
torch.Size([1, 1, 32, 3])
torch.Size([1, 32, 32, 3])
torch.Size([1, 1024, 3])


In [7]:
# Hidden width of the vector channels: the larger of the input and output
# vector-channel counts used in the cells below (1024 in, 256 out).
DIM_VECTORS_IN, DIM_VECTORS_OUT = 1024, 256
h = max(DIM_VECTORS_IN, DIM_VECTORS_OUT)
h

1024

## Check the tensor shapes at each step of a GVP layer

In [10]:
from torch import nn, einsum

In [16]:
# Vector-path weights: Wh mixes the input vector channels into dim_h hidden
# channels; Wu mixes the hidden channels into the output vector channels.
dim_vectors_in, dim_vectors_out = 1024, 256
dim_h = h

Wh = nn.Parameter(torch.randn(dim_vectors_in, dim_h))   # (in, hidden)
Wu = nn.Parameter(torch.randn(dim_h, dim_vectors_out))  # (hidden, out)

Wu

Parameter containing:
tensor([[ 0.7505, -0.5612,  0.4169,  ..., -1.2299, -1.6742,  1.1036],
        [-2.7312,  1.3213,  0.9751,  ...,  1.0150, -0.3535, -1.4883],
        [-0.6762,  0.7622, -0.4388,  ...,  0.5518, -0.8311,  0.8096],
        ...,
        [-1.0033,  1.2610, -0.6265,  ...,  1.0308, -0.4087, -0.9085],
        [-1.6114, -0.0151, -0.1247,  ...,  0.2509, -0.0923, -1.4548],
        [-0.9851, -1.1293, -0.2806,  ..., -1.0230,  0.8629, -0.2491]],
       requires_grad=True)

In [20]:
feats = torch.randn(1, 512)

# Unpack (batch, n_feats) from feats and (batch, n_vectors, n_coords) from vectors.
# Distinct names avoid shadowing IPython's `_` (once assigned, IPython stops
# updating its last-output variable) and the `b` used in the walkthrough above.
batch, n_feats, batch_v, n_vectors, n_coords = *feats.shape, *vectors.shape
assert batch == batch_v, "feats and vectors must share the batch dimension"

print(batch, n_feats, batch_v, n_vectors, n_coords)

1 512 1 32 3


In [33]:
# diff_matrix flattens the 32 x 32 pairwise differences of `vectors` into 1024 rows.
diff_matrix(vectors).shape

torch.Size([1, 1024, 3])

In [40]:
# Vh[b, h, c] = sum_v V[b, v, c] * Wh[v, h]: Wh mixes the vector channels and is
# applied independently to each coordinate c, so it commutes with a rotation of
# the coordinates. For batch size 1 this equals (Wh.T @ V[0]).unsqueeze(0).
V = diff_matrix(vectors)
Vh = einsum('b v c, v h -> b h c', V, Wh)
Vu = einsum('b h c, h u -> b u c', Vh, Wu)

# Only the last expression is displayed, so check Vh's shape explicitly.
assert Vh.shape == (V.shape[0], Wh.shape[1], V.shape[-1])
Vu.shape

torch.Size([1, 256, 3])

In [46]:
# Scalar channel: L2 norm of each hidden vector. The coordinate axis is dropped
# so that sh is 2-D and can be concatenated with feats.
sh = torch.norm(Vh, p = 2, dim = -1)
# Gate input: L2 norm of each output vector, kept as (b, u, 1) so it broadcasts
# against Vu when the vectors are scaled.
vu = torch.norm(Vu, p = 2, dim = -1, keepdim = True)

# Only the last expression is displayed, so check sh's shape explicitly.
assert sh.shape == Vh.shape[:-1]
vu.shape

torch.Size([1, 256, 1])

In [48]:
# Scalar input to the output head: [feats | sh] along the last (channel) axis.
s = torch.cat([feats, sh], dim=-1)
s.shape

torch.Size([1, 1536])

In [55]:
dim_feats_in, dim_feats_out = 512, 512

# Sigmoid has no state, so one instance can serve both the scalar and vector paths.
feats_activation = nn.Sigmoid()
vectors_activation = feats_activation

# Scalar output head: Linear over the (dim_feats_in + dim_h) concatenation, then sigmoid.
to_feats_out = nn.Sequential(
    nn.Linear(dim_feats_in + dim_h, dim_feats_out),
    feats_activation,
)

In [53]:
# Scalar output of the layer: the Linear + sigmoid head applied to [feats, sh].
feats_out = to_feats_out(s)
feats_out.shape

torch.Size([1, 512])

In [57]:
# Vector output: scale each Vu vector by sigmoid(its own norm). Only the length
# changes, not the direction, which keeps this step rotation-equivariant.
vectors_out = vectors_activation(vu) * Vu

vectors_out.shape

torch.Size([1, 256, 3])

## Check for Equivariance and Invariance

In [58]:
def random_rotation(dim=3):
    """Sample a random orthogonal ``dim x dim`` matrix.

    Uses the QR decomposition of a Gaussian matrix. Each column of ``q`` is
    multiplied by the sign of the matching diagonal entry of ``r``. This makes
    the decomposition unique, so the sample is uniformly (Haar) distributed
    over the orthogonal group. Without this step, QR's sign convention biases
    the sample.

    The determinant may be +1 (rotation) or -1 (reflection). Both preserve
    norms, so either suffices for invariance and equivariance checks.

    Args:
        dim: size of the square matrix (default 3, i.e. xyz coordinates).

    Returns:
        Tensor of shape ``(dim, dim)`` with ``q.T @ q == I`` up to float error.
    """
    q, r = torch.linalg.qr(torch.randn(dim, dim))
    signs = torch.sign(torch.diagonal(r))
    # sign(0) == 0 would zero a column; a zero diagonal only occurs for singular input.
    signs[signs == 0] = 1
    return q * signs

In [61]:
from geometric_vector_perceptron import GVP

# One random orthogonal transform, shared by both runs.
R = random_rotation()

# Same layer sizes as the manual walkthrough above.
model = GVP(
    dim_vectors_in=1024,
    dim_feats_in=512,
    dim_vectors_out=256,
    dim_feats_out=512,
)

# Scalar features are left untouched; only the point cloud is transformed
# (row vectors, so the transform is applied on the right: vectors @ R).
inputs = (feats, diff_matrix(vectors))
inputs_rotated = (feats, diff_matrix(vectors @ R))

feats_out, vectors_out = model(inputs)
feats_out_r, vectors_out_r = model(inputs_rotated)


### Features (scalar) invariance w.r.t. rotations and reflections

In [67]:
# Scalar outputs for the unrotated input (first batch element).
feats_out[0]

tensor([1.0000e+00, 1.0000e+00, 2.5277e-12, 8.6966e-13, 8.0424e-03, 1.0000e+00,
        7.3293e-05, 8.4491e-26, 5.4068e-15, 1.0000e+00, 5.5394e-18, 1.0000e+00,
        1.0000e+00, 9.9999e-01, 1.0000e+00, 8.3762e-11, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 4.9738e-06, 1.0000e+00, 1.0000e+00, 1.0000e+00, 9.9440e-01,
        1.8495e-01, 1.0000e+00, 7.1932e-08, 9.6137e-13, 7.0642e-22, 4.5547e-14,
        4.3304e-05, 1.0000e+00, 1.0000e+00, 4.4329e-26, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 9.9999e-01, 1.0000e+00, 3.8878e-07, 2.0826e-11, 8.4437e-01,
        1.0000e+00, 1.0000e+00, 9.9995e-01, 1.0702e-21, 7.1420e-05, 9.4205e-18,
        2.8663e-08, 3.3247e-34, 4.2385e-02, 9.8943e-01, 2.3855e-07, 3.3991e-13,
        1.0000e+00, 9.2004e-01, 1.8622e-10, 3.9961e-02, 9.6349e-02, 2.6811e-08,
        1.2512e-12, 3.4125e-23, 8.2422e-25, 2.6073e-11, 9.9969e-01, 1.0000e+00,
        1.0000e+00, 4.5652e-24, 1.0000e+00, 9.9999e-01, 1.0000e+00, 8.9313e-01,
        9.9779e-01, 9.7619e-01, 9.9999e-

In [64]:
# Scalar outputs for the rotated input; invariance predicts they match the cell above.
feats_out_r[0]

tensor([1.0000e+00, 1.0000e+00, 2.5277e-12, 8.6966e-13, 8.0425e-03, 1.0000e+00,
        7.3292e-05, 8.4490e-26, 5.4068e-15, 1.0000e+00, 5.5394e-18, 1.0000e+00,
        1.0000e+00, 9.9999e-01, 1.0000e+00, 8.3762e-11, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 4.9738e-06, 1.0000e+00, 1.0000e+00, 1.0000e+00, 9.9440e-01,
        1.8495e-01, 1.0000e+00, 7.1932e-08, 9.6137e-13, 7.0641e-22, 4.5547e-14,
        4.3304e-05, 1.0000e+00, 1.0000e+00, 4.4329e-26, 1.0000e+00, 1.0000e+00,
        1.0000e+00, 9.9999e-01, 1.0000e+00, 3.8878e-07, 2.0826e-11, 8.4437e-01,
        1.0000e+00, 1.0000e+00, 9.9995e-01, 1.0702e-21, 7.1420e-05, 9.4206e-18,
        2.8664e-08, 3.3248e-34, 4.2384e-02, 9.8943e-01, 2.3855e-07, 3.3991e-13,
        1.0000e+00, 9.2004e-01, 1.8622e-10, 3.9961e-02, 9.6350e-02, 2.6811e-08,
        1.2512e-12, 3.4125e-23, 8.2422e-25, 2.6073e-11, 9.9969e-01, 1.0000e+00,
        1.0000e+00, 4.5651e-24, 1.0000e+00, 9.9999e-01, 1.0000e+00, 8.9313e-01,
        9.9779e-01, 9.7619e-01, 9.9999e-

In [71]:
# Element-wise invariance check. Rounding sigmoid outputs to 0/1 only detects
# differences that straddle 0.5, so compare with a tolerance instead.
torch.isclose(feats_out[0], feats_out_r[0], rtol=1e-5, atol=1e-5)

tensor([True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, True, True, True, True, True, True, True,
        True, True, True, True, True, Tr

In [82]:
# Invariance holds if every scalar output matches within float tolerance
# (rounding to integers would hide almost any deviation in [0, 1]).
torch.isclose(feats_out, feats_out_r, rtol=1e-5, atol=1e-5).all()

tensor(True)

In [75]:
# Largest absolute deviation; without .abs() negative differences would be missed.
(feats_out - feats_out_r).abs().max()

tensor(2.8610e-06, grad_fn=<MaxBackward1>)

### Vector equivariance w.r.t. rotations and reflections

In [76]:
# Right-multiplying by an orthogonal R preserves singular values, so they should
# agree between the unrotated and rotated outputs. Read only .S so that the
# earlier `s` (concatenated scalar features) and `v` are not overwritten.
sv_out = torch.svd(vectors_out).S

sv_out

tensor([[25183.6230, 21765.9180, 18937.5078]], grad_fn=<SvdHelperBackward>)

In [77]:
# Singular values of the rotated run; these should match the unrotated ones.
# Read only .S so that the earlier `s` and `v` are not overwritten.
sv_out_r = torch.svd(vectors_out_r).S

sv_out_r

tensor([[25183.6211, 21765.9180, 18937.5098]], grad_fn=<SvdHelperBackward>)

In [80]:
# First 10 output vectors for the unrotated input. Equivariance does not predict
# these raw values for the rotated run; it predicts vectors_out @ R instead.
vectors_out[0][:10]

tensor([[  211.3972,   507.6278,  1003.5482],
        [ -569.8422,  -625.5551,   349.9399],
        [ 2238.3110, -1697.3890,  1406.7788],
        [ -544.5129, -1574.9205,  -700.5328],
        [-2667.7134,  -123.2306, -1287.6260],
        [ 1019.3002,  -937.6866,  2115.3357],
        [ -258.0677,  1336.5889,  1083.5842],
        [ -917.5706, -1021.5850,  1085.4001],
        [ -988.6802, -1194.0969, -2335.5066],
        [  880.2687,  1403.6772,  1974.6440]], grad_fn=<SliceBackward>)

In [81]:
# Equivariance check: with row vectors, rotating the input by R should rotate the
# output by R, i.e. f(V @ R) == f(V) @ R. The raw values shown here therefore
# differ from the cell above, but must match vectors_out @ R up to float error.
max_abs_err = (vectors_out @ R - vectors_out_r).abs().max()
assert max_abs_err <= 1e-4 * vectors_out.abs().max(), max_abs_err

vectors_out_r[0][:10]

tensor([[ 6.1974e+02, -8.3118e+02, -4.8430e+02],
        [ 7.0712e+02,  3.7021e+02,  4.4881e+02],
        [ 2.4127e+02, -2.2940e+03,  2.1329e+03],
        [-1.4718e+00,  1.1169e+03,  1.4213e+03],
        [ 4.1144e+02,  2.8962e+03, -4.8228e+02],
        [ 1.3501e+03, -1.8540e+03,  1.0643e+03],
        [ 7.9357e+02, -6.4784e+02, -1.4063e+03],
        [ 1.5698e+03,  3.0181e+02,  7.1300e+02],
        [-1.1706e+03,  2.3354e+03,  1.0166e+03],
        [ 8.9600e+02, -2.0806e+03, -1.2300e+03]], grad_fn=<SliceBackward>)

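- The vector outputs differ with and without the rotation. That alone does not show equivariance; it only shows that the output depends on the input's orientation.
- Equivariance means that transforming the input transforms the output in the same way: $f(g(x)) = g(f(x))$ for every transformation $g$. For row vectors $V$ and an orthogonal matrix $R$ this reads $f(VR) = f(V)R$.
- The matching singular values of `vectors_out` and `vectors_out_r` are consistent with this, because right-multiplying by an orthogonal matrix leaves singular values unchanged. The direct test is to compare `vectors_out @ R` with `vectors_out_r`.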