### Here we just check if the Encoder Network can correctly classify the 8 Tetris Shapes.


In [1]:
from functools import partial

In [2]:
import torch
from giae.se3.data import DataModule, TetrisDatasetPyG
from giae.se3.modules import Encoder

In [3]:
# Build the encoder under test. emb_dim=2 matches the input width of the
# Linear(2, 8) classification head created further down in this notebook.
encoder = Encoder(
    hidden_dim=32,
    emb_dim=2,
    num_layers=5,
    layer_norm=False,
)

In [4]:
# Dataset factory for the clean setting: rotation is switched off and both the
# noise and translation levels are zero. `partial` only pre-binds these keyword
# arguments; the dataset itself is constructed later by whoever calls `dataset`.
dataset = partial(
    TetrisDatasetPyG,
    rotate=False,
    num_elements=10000,
    noise_level=0.0,
    translation_level=0.0,
)

In [5]:
# Wrap the dataset factory in the project's DataModule.
# train_samples equals the num_elements value bound into `dataset` above.
datamodule = DataModule(
    dataset=dataset,
    train_samples=10_000,
    batch_size=64,
    num_workers=0,  # load batches in the main process
    num_eval_samples=200,
)

In [6]:
# Training batches, requested without shuffling.
# NOTE(review): shuffle=False is unusual for a training loader - confirm it is
# intentional (e.g. because the dataset already yields samples in random order).
train_loader = datamodule.train_dataloader(
    shuffle=False,
)

In [7]:
# Device used for the model and every batch in this notebook.
# "cuda:8" only exists on hosts exposing at least nine GPUs, so fall back to
# the first GPU, and finally to the CPU, instead of failing on the first
# `.to(device)` call when that GPU is not present.
if torch.cuda.device_count() > 8:
    device = "cuda:8"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [8]:
# Pull a single batch from the loader (handy for inspecting one sample batch).
batch_iterator = iter(train_loader)
data = next(batch_iterator)

In [9]:
# Move the sample batch onto the selected device.
data = data.to(device)

In [10]:
# Move the encoder onto the same device as the data.
encoder = encoder.to(device)

In [11]:
# Linear classification head: 2 input features (the encoder's emb_dim) mapped
# to 8 logits, one per Tetris shape.
lin = torch.nn.Linear(in_features=2, out_features=8, device=device)

In [12]:
# Multi-class classification loss on raw logits, averaged over the batch
# ("mean" is the default reduction, spelled out here for clarity).
loss_fnc = torch.nn.CrossEntropyLoss(reduction="mean")

In [13]:
# Optimise the encoder and the linear head jointly with a single Adam instance.
params = [*encoder.parameters(), *lin.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

In [14]:
# Number of full passes over the training loader.
nepochs: int = 20

In [15]:
# Train the encoder and the linear head to classify the shapes.
# Every 50th step the loss and accuracy of the current batch are printed.
for i in range(nepochs):
    for j, data in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(device)

        # CrossEntropyLoss takes class indices; argmax over the last axis of
        # `data.label` yields them (presumably labels are one-hot - confirm).
        target = data.label.argmax(-1)

        shape_embed, point_embed, rot, transl_out, vout = encoder(
            pos=data.pos,
            batch=data.batch,
            batch_num_nodes=torch.bincount(data.batch),
            edge_index=data.edge_index,
            use_fc=True,
        )
        y_pred = lin(shape_embed)
        loss = loss_fnc(y_pred, target)
        loss.backward()
        optimizer.step()

        if j % 50 == 0:
            # Vectorised accuracy: the builtin sum() would iterate over the
            # tensor one element at a time in Python.
            with torch.no_grad():
                acc = (y_pred.argmax(-1) == target).float().mean()
            print(f" Epoch: {i}/{nepochs}, Step {j}/{len(train_loader)}, Loss: {loss.item():.4f}, Acc: {acc.item():.4f}")

 Epoch: 0/20, Step 0/157, Loss: 2.1655, Acc: 0.1875
 Epoch: 0/20, Step 50/157, Loss: 0.6951, Acc: 0.5781
 Epoch: 0/20, Step 100/157, Loss: 0.2515, Acc: 0.8750
 Epoch: 0/20, Step 150/157, Loss: 0.4860, Acc: 0.7812
 Epoch: 1/20, Step 0/157, Loss: 0.4103, Acc: 0.7344
 Epoch: 1/20, Step 50/157, Loss: 0.1083, Acc: 0.9531
 Epoch: 1/20, Step 100/157, Loss: 1.8869, Acc: 0.2656
 Epoch: 1/20, Step 150/157, Loss: 1.3040, Acc: 0.5000
 Epoch: 2/20, Step 0/157, Loss: 1.2980, Acc: 0.6406
 Epoch: 2/20, Step 50/157, Loss: 0.6851, Acc: 0.7031
 Epoch: 2/20, Step 100/157, Loss: 0.1898, Acc: 0.8906
 Epoch: 2/20, Step 150/157, Loss: 0.0023, Acc: 1.0000
 Epoch: 3/20, Step 0/157, Loss: 0.0014, Acc: 1.0000
 Epoch: 3/20, Step 50/157, Loss: 0.0002, Acc: 1.0000
 Epoch: 3/20, Step 100/157, Loss: 0.0001, Acc: 1.0000
 Epoch: 3/20, Step 150/157, Loss: 0.0001, Acc: 1.0000
 Epoch: 4/20, Step 0/157, Loss: 0.0001, Acc: 1.0000
 Epoch: 4/20, Step 50/157, Loss: 0.0001, Acc: 1.0000
 Epoch: 4/20, Step 100/157, Loss: 0.0000, A

### Let's check what happens when we create "new" shapes by adding some Gaussian noise to each point
We also rotate and translate the point cloud. Because `shape_embed` is invariant, it is not affected by these transformations; only the `vout` tensor is.

In [16]:
# Second experiment: same pipeline, but the dataset factory now enables
# rotation, a noise level of 0.01 and a translation level of 5.0.
dataset = partial(
    TetrisDatasetPyG,
    rotate=True,
    num_elements=10000,
    noise_level=0.01,
    translation_level=5.0,
)

datamodule = DataModule(
    dataset=dataset,
    train_samples=10_000,
    batch_size=64,
    num_workers=0,
    num_eval_samples=200,
)
train_loader = datamodule.train_dataloader(shuffle=False)

# Fresh encoder, head and optimizer, so this run does not start from the
# weights trained in the previous experiment.
encoder = Encoder(
    hidden_dim=32,
    emb_dim=2,
    num_layers=5,
    layer_norm=False,
).to(device)
lin = torch.nn.Linear(in_features=2, out_features=8, device=device)
params = [*encoder.parameters(), *lin.parameters()]
optimizer = torch.optim.Adam(params, lr=1e-3)

In [17]:
# Train the freshly initialised encoder and head on the augmented dataset.
# Every 50th step the loss and accuracy of the current batch are printed.
for i in range(nepochs):
    for j, data in enumerate(train_loader):
        optimizer.zero_grad()
        data = data.to(device)

        # CrossEntropyLoss takes class indices; argmax over the last axis of
        # `data.label` yields them (presumably labels are one-hot - confirm).
        target = data.label.argmax(-1)

        shape_embed, point_embed, rot, transl_out, vout = encoder(
            pos=data.pos,
            batch=data.batch,
            batch_num_nodes=torch.bincount(data.batch),
            edge_index=data.edge_index,
            use_fc=True,
        )
        y_pred = lin(shape_embed)
        loss = loss_fnc(y_pred, target)
        loss.backward()
        optimizer.step()

        if j % 50 == 0:
            # Vectorised accuracy: the builtin sum() would iterate over the
            # tensor one element at a time in Python.
            with torch.no_grad():
                acc = (y_pred.argmax(-1) == target).float().mean()
            print(f" Epoch: {i}/{nepochs}, Step {j}/{len(train_loader)}, Loss: {loss.item():.4f}, Acc: {acc.item():.4f}")

 Epoch: 0/20, Step 0/157, Loss: 2.1845, Acc: 0.1406
 Epoch: 0/20, Step 50/157, Loss: 0.4487, Acc: 0.7812
 Epoch: 0/20, Step 100/157, Loss: 0.2378, Acc: 0.8594
 Epoch: 0/20, Step 150/157, Loss: 1.5658, Acc: 0.4531
 Epoch: 1/20, Step 0/157, Loss: 1.8023, Acc: 0.3281
 Epoch: 1/20, Step 50/157, Loss: 0.1199, Acc: 0.9375
 Epoch: 1/20, Step 100/157, Loss: 1.6783, Acc: 0.3281
 Epoch: 1/20, Step 150/157, Loss: 0.3064, Acc: 0.8125
 Epoch: 2/20, Step 0/157, Loss: 0.2008, Acc: 0.9375
 Epoch: 2/20, Step 50/157, Loss: 1.9656, Acc: 0.2500
 Epoch: 2/20, Step 100/157, Loss: 1.5561, Acc: 0.4219
 Epoch: 2/20, Step 150/157, Loss: 0.9270, Acc: 0.5938
 Epoch: 3/20, Step 0/157, Loss: 0.7318, Acc: 0.6250
 Epoch: 3/20, Step 50/157, Loss: 1.3451, Acc: 0.4844
 Epoch: 3/20, Step 100/157, Loss: 0.5641, Acc: 0.9219
 Epoch: 3/20, Step 150/157, Loss: 0.0512, Acc: 1.0000
 Epoch: 4/20, Step 0/157, Loss: 0.0189, Acc: 1.0000
 Epoch: 4/20, Step 50/157, Loss: 0.0008, Acc: 1.0000
 Epoch: 4/20, Step 100/157, Loss: 0.0006, A