In [57]:
from model_util import ModelSpatialSteer, ModelSpatial, ClassicalCNN, ModelVoxel
import torch
from torch.utils.data import Dataset, DataLoader

In [58]:
# Seed torch's global RNG first so the synthetic datasets below (and,
# presumably, the models' weight initialisation) are reproducible under
# Restart & Run All. The trailing ';' hides the Generator repr in the cell output.
SEED = 0
torch.manual_seed(SEED);

# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [59]:
# Hyperparameters handed to every model constructor below. Some of them are only for logging the information.
# The class is used as a plain attribute bag (it is never instantiated), so every
# field is a class attribute. NOTE(review): the meaning of each field is defined
# in model_util, which is not visible here; the inline notes are name-based guesses.
class args:
    num_shells = 1
    ray_len = None
    num_rays = 5
    samples_per_ray = 2
    b_size = 10                    # presumably the batch size
    watson_param = 10
    num_epoch = 10
    lr = 1e-3                      # presumably the learning rate
    model_capacity = 'smalls'      # NOTE(review): confirm 'smalls' is a value model_util accepts
    data_aug = False
    iter = 100                     # attribute name only; does not affect the builtin outside this class
    pooling = 'max'
    bias = True
    spatial_kernel_size = [3] * 3  # one entry per spatial axis (3 x 3 x 3 kernel)
    lin_bias = True
    lin_bn = True
    spatial_bias = True

# Synthetic dataset for the SE(3) Group CNN and the T<sup>3</sup> × SO(3) Group CNN

In [60]:
class SE3Dataset(Dataset):
    """Random synthetic classification data for the SE(3) and T^3 x SO(3) group CNNs.

    Each sample is a uniform-random tensor of shape ``SAMPLE_SHAPE`` (a grid of
    interpolated spherical functions) paired with an integer label in
    ``[0, num_classes)``. NOTE(review): the meaning of each axis is defined by
    the consuming models in model_util, which is not visible here.

    Args:
        num_samples: Number of samples to generate (default 100).
        num_classes: Number of distinct labels (default 4).
        generator: Optional ``torch.Generator`` for reproducible data. When
            ``None``, torch's global RNG is used, as before.
    """

    SAMPLE_SHAPE = (12, 11, 1, 7, 7, 7)

    def __init__(self, num_samples=100, num_classes=4, generator=None):
        # Draw data before labels so the RNG consumption order matches the original.
        self.data = torch.rand(num_samples, *self.SAMPLE_SHAPE, generator=generator)
        self.labels = torch.randint(num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]

# Synthetic dataset for the SO(3) Group CNN

In [61]:
class SO3Dataset(Dataset):
    """Random synthetic classification data for the SO(3) group CNN (baseline).

    Each sample is a uniform-random tensor of shape ``SAMPLE_SHAPE`` (individual
    spherical functions from each voxel) paired with an integer label in
    ``[0, num_classes)``. It has the same leading axes as ``SE3Dataset``, but the
    trailing three axes have extent 1 instead of 7 (presumably a single voxel;
    NOTE(review): confirm against ModelVoxel in model_util).

    Args:
        num_samples: Number of samples to generate (default 100).
        num_classes: Number of distinct labels (default 4).
        generator: Optional ``torch.Generator`` for reproducible data. When
            ``None``, torch's global RNG is used, as before.
    """

    SAMPLE_SHAPE = (12, 11, 1, 1, 1, 1)

    def __init__(self, num_samples=100, num_classes=4, generator=None):
        # Draw data before labels so the RNG consumption order matches the original.
        self.data = torch.rand(num_samples, *self.SAMPLE_SHAPE, generator=generator)
        self.labels = torch.randint(num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]

# Synthetic dataset for the Classical CNN

In [62]:
class ClassicalDataset(Dataset):
    """Random synthetic classification data for the classical CNN.

    Each sample is a uniform-random tensor of shape ``SAMPLE_SHAPE`` (a grid of
    voxels) paired with an integer label in ``[0, num_classes)``.
    NOTE(review): the leading axis of size 90 is presumably the per-voxel channel
    count; confirm against ClassicalCNN in model_util.

    Args:
        num_samples: Number of samples to generate (default 100).
        num_classes: Number of distinct labels (default 4).
        generator: Optional ``torch.Generator`` for reproducible data. When
            ``None``, torch's global RNG is used, as before.
    """

    SAMPLE_SHAPE = (90, 7, 7, 7)

    def __init__(self, num_samples=100, num_classes=4, generator=None):
        # Draw data before labels so the RNG consumption order matches the original.
        self.data = torch.rand(num_samples, *self.SAMPLE_SHAPE, generator=generator)
        self.labels = torch.randint(num_classes, (num_samples,), generator=generator)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(sample, label)`` pair at position ``idx``."""
        return self.data[idx], self.labels[idx]

In [63]:
se3_dataset = SE3Dataset()
# Take the batch size from `args` so the models (which are built from `args`)
# and the loader agree, and the logged b_size is the one actually used.
se3_dl = DataLoader(se3_dataset, batch_size=args.b_size, shuffle=True)

In [64]:
so3_dataset = SO3Dataset()
# Take the batch size from `args` so the models (which are built from `args`)
# and the loader agree, and the logged b_size is the one actually used.
so3_dl = DataLoader(so3_dataset, batch_size=args.b_size, shuffle=True)

In [65]:
classical_dataset = ClassicalDataset()
# Take the batch size from `args` so the models (which are built from `args`)
# and the loader agree, and the logged b_size is the one actually used.
classical_dl = DataLoader(classical_dataset, batch_size=args.b_size, shuffle=True)

# Demo to run the SE(3) Group CNN (Ours)

In [66]:
# Build the SE(3) group CNN from the shared `args` hyperparameters. It is the only
# model here that also receives `device` at construction time.
# NOTE(review): `full_group=True` presumably selects the full SE(3) variant; confirm in model_util.
# The "Clockwise ring, making it counterclockwise." lines in this cell's output are printed during construction.
model_se3 = ModelSpatialSteer(args, device=device, full_group=True)

Clockwise ring, making it counterclockwise.
Clockwise ring, making it counterclockwise.
Clockwise ring, making it counterclockwise.
Clockwise ring, making it counterclockwise.
Clockwise ring, making it counterclockwise.
Clockwise ring, making it counterclockwise.


In [67]:
model_se3 = model_se3.to(device)
# Forward-only shape check: nothing calls backward(), so run under no_grad to
# avoid building autograd graphs. Outputs are unchanged.
# NOTE(review): the model stays in train mode, so any BatchNorm layers update
# their running stats from this random data; call .eval() for true inference.
with torch.no_grad():
    for data, _label in se3_dl:
        out = model_se3(data.to(device))
        print(out.shape)

torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])


# Demo to run the T<sup>3</sup> × SO(3) Group CNN (OursDecoupled)

In [68]:
# Build the decoupled T^3 x SO(3) group CNN from the shared `args` hyperparameters.
# It consumes the same SE3Dataset batches as the SE(3) model (see the next cell).
model_decoupled = ModelSpatial(args)

In [69]:
model_decoupled = model_decoupled.to(device)
# Forward-only shape check: nothing calls backward(), so run under no_grad to
# avoid building autograd graphs. Outputs are unchanged.
# NOTE(review): the model stays in train mode, so any BatchNorm layers update
# their running stats from this random data; call .eval() for true inference.
with torch.no_grad():
    for data, _label in se3_dl:
        out = model_decoupled(data.to(device))
        print(out.shape)

torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])


# Demo to run the SO(3) Group CNN (Baseline)

In [70]:
# Build the SO(3) group CNN baseline from the shared `args` hyperparameters.
# It consumes SO3Dataset batches (one voxel per sample; see the next cell).
model_so3 = ModelVoxel(args)

In [71]:
model_so3 = model_so3.to(device)
# Forward-only shape check: nothing calls backward(), so run under no_grad to
# avoid building autograd graphs. Outputs are unchanged.
# NOTE(review): the model stays in train mode, so any BatchNorm layers update
# their running stats from this random data; call .eval() for true inference.
with torch.no_grad():
    for data, _label in so3_dl:
        out = model_so3(data.to(device))
        print(out.shape)

torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])


# Demo to run the Classical CNN

In [72]:
# Build the classical (non-equivariant) CNN from the shared `args` hyperparameters.
# It consumes ClassicalDataset voxel-grid batches (see the next cell).
model_classical = ClassicalCNN(args)

In [73]:
model_classical = model_classical.to(device)
# Forward-only shape check: nothing calls backward(), so run under no_grad to
# avoid building autograd graphs. Outputs are unchanged.
# NOTE(review): the model stays in train mode, so any BatchNorm layers update
# their running stats from this random data; call .eval() for true inference.
with torch.no_grad():
    for data, _label in classical_dl:
        out = model_classical(data.to(device))
        print(out.shape)

torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
torch.Size([10, 4])
