# **This is an example notebook that shows how to integrate RelEdgePool (and its components) into a 3D deep learning pipeline.**
##### RelEdgePool is integrated into your model, so the rest of your pipeline remains unchanged.
---

### Imports
---

In [None]:
import torch
import torch.nn as nn
from pytorch3d.io import load_objs_as_meshes
from torch.utils.data import DataLoader

'''
    The GraphConvolver class has logic for performing graph convolution on batched meshes having different numbers of vertices and edges.
    Similarly, the PoolingExecutor class has logic for performing pooling with RelEdgePool on batched meshes having different numbers of vertices and edges.
'''
from utils.convolution_executor import GraphConvolver
from utils.pooling_executor import PoolingExecutor
from utils.experimentation_datasets_dataloader import ExperimentationDatasets

### Data loading
---

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
def collate_meshes(batch):
    """
    Collate a DataLoader batch of ``(mesh_path, label)`` pairs into model inputs.

    The ExperimentationDatasets class yields mesh file paths rather than loaded
    meshes, so loading happens here. Every path in the batch is read as an OBJ
    mesh (textures skipped) onto the module-level ``device``.

    Returns:
        A 3-tuple ``(vertices_batch, edges_batch, labels_tensor)``:
        - ``vertices_batch``: the result of ``Meshes.verts_padded()`` for the
          whole batch.
        - ``edges_batch``: a Python list with one ``edges_packed()`` tensor per
          mesh. It is kept as a list because the meshes have different numbers
          of edges.
        - ``labels_tensor``: the labels as a ``torch.long`` tensor on ``device``.
    """
    paths, targets = zip(*batch)

    # Load every mesh of the batch in one call; textures are not used by the model.
    meshes = load_objs_as_meshes(list(paths), load_textures=False, device=device)

    padded_verts = meshes.verts_padded()

    per_mesh_edges = []
    for single_mesh in meshes:
        per_mesh_edges.append(single_mesh.edges_packed())

    return (
        padded_verts,
        per_mesh_edges,
        torch.tensor(targets, dtype=torch.long, device=device),
    )

In [None]:
dataset_root_directory = "../../datasets/shrec"  # point this at the shrec or cubes dataset folder
try:
    train_dataset = ExperimentationDatasets(dataset_root_directory, split='train')
    test_dataset = ExperimentationDatasets(dataset_root_directory, split='test')
except FileNotFoundError:
    # Show the hint, then re-raise. The DataLoader cell below needs both
    # datasets; swallowing the error would only make it fail later with a
    # confusing NameError on train_dataset / test_dataset.
    print("Dataset not found. Make sure to execute the relevant script in the datasets folder.")
    raise

In [None]:
# Both loaders batch 32 mesh paths at a time and turn them into model inputs
# via collate_meshes; only the training set is reshuffled each epoch.
train_loader = DataLoader(train_dataset, collate_fn=collate_meshes, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, collate_fn=collate_meshes, batch_size=32, shuffle=False)

### Model Architecture
---

In [None]:
class DummyGCNN(nn.Module):
    """
    A dummy graph convolutional neural network with RelEdgePool integrated into it.

    Architecture:
    - Five stages, each: graph convolution -> LayerNorm -> LeakyReLU(0.2) ->
      dropout -> RelEdgePool pooling (via PoolingExecutor).
    - A mean over the remaining vertices (dim=1).
    - A two-layer MLP head that produces class logits.

    Args:
        num_classes: number of output logits. Defaults to 30, the value that
            was previously hard-coded. Set it to the number of classes in
            your dataset.
    """

    def __init__(self, num_classes=30):
        super().__init__()
        # Graph convolution layers: 3 -> 16 -> 32 -> 64 -> 32 -> 16 channels.
        # Attribute names and creation order are kept stable so state_dict keys
        # and parameter initialization order do not change.
        self.gc1 = GraphConvolver(in_channels=3, out_channels=16)
        self.gc2 = GraphConvolver(in_channels=16, out_channels=32)
        self.gc3 = GraphConvolver(in_channels=32, out_channels=64)
        self.gc4 = GraphConvolver(in_channels=64, out_channels=32)
        self.gc5 = GraphConvolver(in_channels=32, out_channels=16)

        # One LayerNorm per stage, sized to that stage's output channels.
        self.ln1 = nn.LayerNorm(normalized_shape=16)
        self.ln2 = nn.LayerNorm(normalized_shape=32)
        self.ln3 = nn.LayerNorm(normalized_shape=64)
        self.ln4 = nn.LayerNorm(normalized_shape=32)
        self.ln5 = nn.LayerNorm(normalized_shape=16)

        # Shared dropout, used after every stage and in the classification head.
        self.dropout = nn.Dropout(p=0.3)
        self.l1 = nn.Linear(in_features=16, out_features=32)
        self.l2 = nn.Linear(in_features=32, out_features=num_classes)

    def _conv_block(self, graph_conv, layer_norm, vertices_batch, edges_batch):
        """Run one conv -> LayerNorm -> LeakyReLU -> dropout -> RelEdgePool stage.

        Returns whatever ``PoolingExecutor(...).pool()`` returns, i.e. the pooled
        ``(vertices_batch, edges_batch)`` pair.
        """
        vertices_batch = graph_conv.convolve(vertices_batch, edges_batch)
        vertices_batch = layer_norm(vertices_batch)
        vertices_batch = nn.functional.leaky_relu(vertices_batch, 0.2)
        vertices_batch = self.dropout(vertices_batch)
        # PoolingExecutor handles meshes with different vertex/edge counts in one batch.
        return PoolingExecutor(vertices_batch, edges_batch).pool()

    def forward(self, vertices_batch, edges_batch):
        """Compute class logits for a batch of meshes.

        Args:
            vertices_batch: batched vertex features; the first stage expects
                3 input channels.
            edges_batch: per-mesh edge tensors, in the format expected by
                GraphConvolver / PoolingExecutor.

        Returns:
            Logits of size ``num_classes`` per mesh, from ``self.l2``.
        """
        stages = (
            (self.gc1, self.ln1),
            (self.gc2, self.ln2),
            (self.gc3, self.ln3),
            (self.gc4, self.ln4),
            (self.gc5, self.ln5),
        )
        for graph_conv, layer_norm in stages:
            vertices_batch, edges_batch = self._conv_block(
                graph_conv, layer_norm, vertices_batch, edges_batch
            )

        # Per the original author, fewer than five vertices remain here. Their
        # mean is taken in case some orphan vertices survive; you could instead
        # keep pooling with RelEdgePool.
        # NOTE(review): if PoolingExecutor returns zero-padded vertex tensors,
        # this mean also averages the padding rows — confirm.
        vertices_batch = torch.mean(vertices_batch, dim=1)
        vertices_batch = nn.functional.leaky_relu(self.l1(vertices_batch), 0.2)
        vertices_batch = self.dropout(vertices_batch)
        vertices_batch = self.l2(vertices_batch)

        return vertices_batch

### Training
---

In [None]:
# Training continues as usual.
gcnn = DummyGCNN()
gcnn.to(device)
loss_fn = nn.CrossEntropyLoss()
# NOTE(review): lr=0.1 is very high for Adam (its documented default is 1e-3);
# kept as in the original.
optimizer = torch.optim.Adam(gcnn.parameters(), lr=0.1)

best_loss = float('inf')
patience_counter = 0
patience_limit = 100  # counted in training batches (optimizer steps), not epochs
epochs = 400
stop_training = False

# Nothing below switches to eval mode, so training mode only needs setting once.
gcnn.train()
for epoch in range(epochs):
    for vertices_batch, edges_batch, train_labels in train_loader:
        optimizer.zero_grad()
        # Call the module (not .forward) so registered hooks run. Copies are
        # passed so the loader's batch is untouched if pooling mutates its inputs.
        preds = gcnn(vertices_batch.clone(), edges_batch.copy())
        loss = loss_fn(preds, train_labels)
        # Plain float for logging and best-loss tracking; storing the tensor
        # itself would keep a reference to its autograd graph.
        loss_value = loss.item()
        print(f"Epoch {epoch}: loss={loss_value:.4f}")

        loss.backward()
        torch.nn.utils.clip_grad_norm_(gcnn.parameters(), max_norm=1.0)
        optimizer.step()

        if loss_value < best_loss:
            best_loss = loss_value
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience_limit:
            print(f"Early stopping triggered during epoch {epoch}")
            stop_training = True
            break

    # `break` above only leaves the batch loop; also leave the epoch loop so
    # early stopping actually ends training.
    if stop_training:
        break