Geometric Model Test Implementation
==============

# Questions:

I managed to get the GCNN working with the latest code, but it still isn't performing as well as a regular CNN. Could you take a brief look and check whether there are any major errors in my configuration?

Some quick questions:
- I made a custom bug fix in GMaxSpatialPool3d to get the shapes right, plus a temporary fix for the padding size. Are these correct?
- I have VRAM problems on a 24GB GPU. I noticed that the GCNN package doesn't support FP16 (type mismatch errors). I think I could implement support, but before I do: is there a technical reason GCNN wouldn't work with FP16 / BF16, such as insufficient precision in FP16?
- I imagine the default kernel size in GCNN is 5 because kernels smaller than 5 have aliasing issues and don't capture features across all angles properly?
- I remember you mentioning that GConv layers don't need as many channels as an equivalent regular Conv, because of the extra operations they perform and the extra VRAM they use? And that you can use 25% of the channels for a fair comparison, because of the default group kernel size of 4?
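A rough sketch of the activation-memory arithmetic behind the VRAM and channel-count questions, assuming a group axis of 4 elements as in the shape tests in this notebook. It counts one dense activation tensor only; weights, gradients, optimizer state and any intermediate buffers inside the gconv layers are not included. `activation_gib` is a hypothetical helper, not part of gconv.

```python
from math import prod

# Bytes per element for the formats discussed in the FP16 question
BYTES_PER_ELEMENT = {"fp32": 4, "fp16": 2, "bf16": 2}


def activation_gib(shape, dtype="fp32"):
    """Size of one dense activation tensor in GiB (no allocator overhead)."""
    return prod(shape) * BYTES_PER_ELEMENT[dtype] / 2**30


# Lifted tensor from the shape tests in this notebook: (B, C, G, D, H, W)
gconv_shape = (5, 16, 4, 128, 128, 128)
# Regular Conv3d activation with C * G = 64 channels: same element count
conv_shape = (5, 64, 128, 128, 128)

print(activation_gib(gconv_shape, "fp32"))  # 2.5
print(activation_gib(conv_shape, "fp32"))   # 2.5
print(activation_gib(gconv_shape, "fp16"))  # 1.25
```

In terms of stored activations, a GConv layer with 16 channels and 4 group elements costs about as much as a regular Conv3d with 64 channels, and a 2-byte format halves the figure.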
  
  
  

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import lightning as pl
from torch import optim
import gconv.gnn as gnn
from torch import Tensor
import numpy as np

from support import * 

# GMaxSpatialPool3d Bug Fix

The provided GMaxSpatialPool3d seems to be bugged: it doesn't reshape the spatial dimensions correctly (https://github.com/ThijsKuipers1995/gconv/blob/main/gconv/gnn/modules/pooling.py).

In [2]:
test = Tensor(np.ones((5, 3, 128, 128, 128)))
x, h = gnn.GLiftingConvSE3(in_channels = 3, out_channels = 16, kernel_size = 5, padding = 2)(test)
out, h = gnn.GMaxSpatialPool3d(2)(x, h)

RuntimeError: shape '[5, 16, 4]' is invalid for input of size 83886080

So I implemented the following fix. I'm not sure whether this is the correct way to do it:

In [3]:
class GMaxSpatialPool3d_fixed(nn.MaxPool3d):
    """
    Performs spatial max pooling on 3d spatial inputs.
    """

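    def forward(self, x: Tensor, H: Tensor) -> tuple[Tensor, Tensor]:
        # x: (B, C, G, D, H, W); merge C and G so MaxPool3d sees (B, C*G, D, H, W)
        y = super().forward(x.flatten(1, 2))
        D2, H2, W2 = y.shape[-3:]

        # restore the (B, C, G) axes; the group elements H pass through unchanged
        return y.view(*x.shape[:3], D2, H2, W2), H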


test = Tensor(np.ones((5, 3, 128, 128, 128)))
x, h = gnn.GLiftingConvSE3(in_channels = 3, out_channels = 16, kernel_size = 5, padding = 2)(test)
out, h = GMaxSpatialPool3d_fixed(2)(x, h)

print(x.shape)
print(out.shape)

torch.Size([5, 16, 4, 128, 128, 128])
torch.Size([5, 16, 4, 64, 64, 64])
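The fix relies on nn.MaxPool3d pooling every channel of a (B, C, D, H, W) input independently, so merging C and G with `flatten(1, 2)` and splitting them again with `view` should give the same result as pooling each group slice on its own. A framework-free sketch of that check in NumPy; the helper names are hypothetical and only mirror the logic of GMaxSpatialPool3d_fixed with kernel size 2.

```python
import numpy as np


def maxpool3d_2x(x):
    """2x2x2 max pooling (stride 2) over the last three axes, via reshape."""
    *lead, D, H, W = x.shape
    x = x.reshape(*lead, D // 2, 2, H // 2, 2, W // 2, 2)
    return x.max(axis=(-5, -3, -1))


def gpool_fixed(x):
    """Mirror of GMaxSpatialPool3d_fixed: merge (C, G), pool, restore (C, G)."""
    B, C, G = x.shape[:3]
    y = maxpool3d_2x(x.reshape(B, C * G, *x.shape[3:]))
    return y.reshape(B, C, G, *y.shape[-3:])


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 8, 8, 8))  # (B, C, G, D, H, W)

# Reference: pool every (channel, group element) slice separately
ref = np.empty((2, 3, 4, 4, 4, 4))
for c in range(3):
    for g in range(4):
        ref[:, c, g] = maxpool3d_2x(x[:, c, g])

print(np.array_equal(gpool_fixed(x), ref))  # True
```

This only verifies the reshape bookkeeping; it says nothing about how the gconv library intends spatial pooling to interact with the group axis.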


# Padding Inconsistency in GSeparableConvSE3

The padding for GLiftingConvSE3 seems to work correctly: with a kernel size of 5, a padding of 0 reduces each spatial dimension by 4 and a padding of 2 keeps the same shape. This is consistent with regular PyTorch Conv layers, where padding = (kernel_size - 1)//2 preserves the shape for odd kernel sizes with stride 1.

However, the padding parameter of GSeparableConvSE3 does not seem to work as expected.

With a kernel size of 5, a padding of 0 reduces each spatial dimension by 4 as before, but a padding of 2 increases it by 4:

In [4]:
test = Tensor(np.ones((5, 3, 128, 128, 128)))

x, h = gnn.GLiftingConvSE3(in_channels = 3, out_channels = 16, kernel_size = 5, padding = 2)(test)
x1, h2 = gnn.GSeparableConvSE3(16, 32, kernel_size = 5, padding = 0, stride=1)(x, h)


print(test.shape)
print(x.shape)
print(x1.shape)

print()
print()

test = Tensor(np.ones((5, 3, 128, 128, 128)))

x, h = gnn.GLiftingConvSE3(in_channels = 3, out_channels = 16, kernel_size = 5, padding = 2)(test)
x1, h2 = gnn.GSeparableConvSE3(16, 32, kernel_size = 5, padding = 2, stride=1)(x, h)


print(test.shape)
print(x.shape)
print(x1.shape)

torch.Size([5, 3, 128, 128, 128])
torch.Size([5, 16, 4, 128, 128, 128])
torch.Size([5, 32, 4, 124, 124, 124])


torch.Size([5, 3, 128, 128, 128])
torch.Size([5, 16, 4, 128, 128, 128])
torch.Size([5, 32, 4, 132, 132, 132])


This suggests that GSeparableConvSE3 handles padding differently internally. I was able to get the shapes to match using a custom padding formula: padding = (kernel_size - 1)//2 - 1.

In this case, to keep the same shape with kernel_size = 5, the padding needs to be set to 1.

In [5]:
test = Tensor(np.ones((5, 3, 128, 128, 128)))

x, h = gnn.GLiftingConvSE3(in_channels = 3, out_channels = 16, kernel_size = 5, padding = 2)(test)
x1, h2 = gnn.GSeparableConvSE3(16, 32, kernel_size = 5, padding = 1, stride=1)(x, h)


print(test.shape)
print(x.shape)
print(x1.shape)

torch.Size([5, 3, 128, 128, 128])
torch.Size([5, 16, 4, 128, 128, 128])
torch.Size([5, 32, 4, 128, 128, 128])
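The three GSeparableConvSE3 measurements can be summarized as an output-size rule. The sketch below compares the standard PyTorch formula with an empirical fit: the 4 × padding growth is read off the data, while the kernel-size term is an assumption borrowed from regular convolution, because only kernel_size = 5 is measured here. One possible explanation, not checked against the gconv source, is that padding is applied in more than one of the separable sub-convolutions. Both functions are hypothetical helpers.

```python
def conv_out_size(n, kernel_size, padding, stride=1):
    """Standard PyTorch convolution output size along one spatial axis."""
    return (n + 2 * padding - kernel_size) // stride + 1


def gsep_out_size_fit(n, kernel_size, padding):
    """Empirical fit to the GSeparableConvSE3 shapes above (stride 1 only).

    The 4 * padding growth comes from the kernel_size = 5 measurements; the
    '- kernel_size + 1' term is an assumption borrowed from regular convolution.
    """
    return n + 4 * padding - kernel_size + 1


print(conv_out_size(128, 5, 2))  # 128, matches GLiftingConvSE3 with padding 2
for p in (0, 1, 2):
    print(p, gsep_out_size_fit(128, 5, p))  # 124, 128, 132 as measured

# The notebook's formula (kernel_size - 1)//2 - 1 under the same extrapolation
for k in (3, 5, 7):
    p = (k - 1) // 2 - 1
    print(k, p, gsep_out_size_fit(128, k, p))  # 3 0 126 / 5 1 128 / 7 2 130
```

Under this extrapolation the formula only preserves the shape for kernel_size = 5; with kernel_size = 3 (the default of geometric_upconv_block, which this notebook always overrides with 5) it would shrink each spatial dimension by 2. A quick shape test like the one above is the safer check for any other kernel size.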


# Testing with a basic regression network

Below is a basic GCNN regressor with a lifting layer, GConv blocks, GMaxSpatialPool3d (fixed version), GAvgGlobalPool, and a regular regression head made of standard fully connected layers.

It doesn't regress quite as well as a regular CNN. I hope there isn't a major problem in the configuration or structure of the model.

Training progress and results can be found in the attached PDF: https://github.com/0-CWANG-0/GCNNtest/blob/main/GCNN_regression.pdf
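To make the channel comparison concrete, here is a sketch that walks the shapes through geometric_basic_regressor_3D (cells 6 and 7). It assumes a 128 x 128 x 128 input (the regression input size is not stated here), convolutions that keep the spatial size, pooling that halves it, and a group axis of 4 elements; `trace_regressor` is a hypothetical helper.

```python
def trace_regressor(spatial=128, group=4):
    """Shapes through geometric_basic_regressor_3D, one row per layer.

    Assumptions: convolutions keep the spatial size, each
    GMaxSpatialPool3d_fixed(2) halves it, and the group axis has `group`
    elements throughout. Rows are (layer, channels, channels * group, side).
    """
    layers = [
        ("lifting", 16, False), ("encoder1", 32, False),
        ("maxpooler1", 32, True), ("encoder2", 64, False),
        ("maxpooler2", 64, True), ("encoder3", 128, False),
        ("maxpooler3", 128, True), ("encoder4", 256, False),
    ]
    rows = []
    for name, channels, pools in layers:
        if pools:
            spatial //= 2
        # channels * group = feature maps a regular Conv3d would need to store
        rows.append((name, channels, channels * group, spatial))
    return rows


for row in trace_regressor():
    print(row)
```

Under these assumptions the last stage holds 256 channels × 4 group elements = 1024 feature maps at 16 voxels per side before GAvgGlobalPool reduces it to the 256 features that `nn.Linear(256, mlp_hidden)` expects, so a regular CNN with the same channel list stores 4× fewer activations per layer.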

In [6]:
class geometric_conv_block(nn.Module):
    def __init__(self, in_channels, 
                       mid_channels, 
                       out_channels, 
                       kernel_size = 5,
                       stride= 1):
        
        super().__init__()
        padding = (kernel_size - 1) // 2 - 1
        
        self.gconv1_internal = gnn.GSeparableConvSE3(in_channels, 
                                                     mid_channels, 
                                                     kernel_size = kernel_size,
                                                     padding = padding,
                                                     stride=1)
        self.normalization_1 = gnn.GBatchNorm3d(mid_channels)
        self.activation_1 = nn.ReLU(inplace=True)
        
        self.gconv2_internal = gnn.GSeparableConvSE3(mid_channels, 
                                                     out_channels, 
                                                     kernel_size = kernel_size,
                                                     padding = padding,
                                                     stride=stride)
        self.normalization_2 = gnn.GBatchNorm3d(out_channels)
        self.activation_2 = nn.ReLU(inplace=True)
        

    def forward(self, x, H):
        
        x, H = self.gconv1_internal(x, H)  
        #x, H = self.normalization_1(x, H)         
        x = self.activation_1(x)
        
        x, H = self.gconv2_internal(x, H)
        #x, H = self.normalization_2(x, H)
        x = self.activation_2(x)
        
        return x, H

In [7]:
class geometric_basic_regressor_3D(pl.LightningModule):
    def __init__(self, 
                 loss_func = "mse",  
                 learning_rate = 0.0001,
                 input_channels = 1,
                 output_channels = 1,
                 mlp_hidden = 256,
                 dropout_p = 0.0):
        
        super().__init__()

        self.save_hyperparameters()

        # lifting to SE3
        self.lifting_layer = gnn.GLiftingConvSE3(in_channels = input_channels,
                                                 out_channels = 16,
                                                 kernel_size = 5,
                                                 padding = 2)

        # Define the encoding layers
        self.encoder1 = geometric_conv_block(16, 16, 32, kernel_size = 5)
        self.maxpooler1 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder2 = geometric_conv_block(32, 32, 64, kernel_size = 5)
        self.maxpooler2 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder3 = geometric_conv_block(64, 64, 128, kernel_size = 5)
        self.maxpooler3 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder4 = geometric_conv_block(128, 128, 256, kernel_size = 5)

    
        # Define the regression layers
        self.global_pool = gnn.GAvgGlobalPool()
        self.regressor = nn.Sequential(nn.Flatten(),                      
                                       nn.Linear(256, mlp_hidden),
                                       nn.LayerNorm(mlp_hidden),
                                       nn.ReLU(inplace=True),
                                       nn.Dropout(dropout_p) if dropout_p > 0 else nn.Identity(),
                                       nn.Linear(mlp_hidden, output_channels)
                                       )
        nn.init.uniform_(self.regressor[-1].weight, -1e-3, 1e-3)
        nn.init.zeros_(self.regressor[-1].bias)

        self.loss_func = loss_func.lower()
        if self.loss_func not in {"mse", "l1", "smoothl1"}:
            raise ValueError("Invalid loss type. Use 'mse', 'l1', or 'smoothl1'.")
        self.learning_rate = learning_rate

    def forward(self, x):
        # lifting SE3
        x, H = self.lifting_layer(x)
        
        # Encoder
        x, H = self.encoder1(x, H)
        x, H = self.maxpooler1(x, H)
        x, H = self.encoder2(x, H)
        x, H = self.maxpooler2(x, H)
        x, H = self.encoder3(x, H)
        x, H = self.maxpooler3(x, H)
        x, H = self.encoder4(x, H)
        
        # regressor
        features = self.global_pool(x, H)
        out  = self.regressor(features.float())     

        return out


    def training_step(self, batch, batch_idx):
        x, y = batch

        # make targets (B, 1) to match the regressor output
        if y.ndim == 1:
            y = y.unsqueeze(-1)
        y = y.float()

        preds = self(x)
        loss  = self.compute_loss(preds, y)
    
        #mae = torch.mean(torch.abs(preds - y))
        
        # MAPE (%)
        denom = y.abs().clamp_min(1e-8)
        mape  = 100.0 * torch.mean(torch.abs((preds - y) / denom))
        
        self.log("train_loss", loss, prog_bar=True, on_epoch=False, on_step=True)
        self.log("train_mape", mape, prog_bar=True, on_epoch=False, on_step=True)
        return loss


    def validation_step(self, batch, batch_idx):
        x, y = batch
    
        if y.ndim == 1:
            y = y.unsqueeze(-1)
        y = y.float()

        preds = self(x)
        loss  = self.compute_loss(preds, y)

        #mae = torch.mean(torch.abs(preds - y))
        
        denom = y.abs().clamp_min(1e-8)
        mape  = 100.0 * torch.mean(torch.abs((preds - y) / denom))
        
        self.log("valid_loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("valid_mape", mape, prog_bar=True, on_epoch=True, on_step=True)
        return loss


    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                         mode='min', 
                                                         patience=3, 
                                                         factor=0.2
                                                         )
        return {"optimizer": optimizer, 
                "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"}}


    
    def compute_loss(self, preds, targets):
        preds = preds.float()
        targets = targets.float()
        if self.loss_func == "mse":
            return nn.MSELoss()(preds, targets)
        elif self.loss_func == "l1":
            return nn.L1Loss()(preds, targets)
        elif self.loss_func == "smoothl1":
            return nn.SmoothL1Loss(beta=0.5)(preds, targets)
        else:
            raise ValueError("Invalid loss type.")


# Testing with UNet for segmentation

I then tried a segmentation model. With this one I struggle with VRAM even on a 24GB GPU: I can only run with a batch size of 1 and significantly reduced channel sizes. I might implement FP16 mixed-precision support.
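On the FP16 precision question, the sketch below derives the range and precision of the three formats from their exponent and mantissa bit widths. FP16 tops out at 65504 with a spacing of about 1e-3 near 1.0, so large sums can overflow; BF16 keeps the FP32 exponent range but has coarser precision. This is the general trade-off that PyTorch's autocast handles by running some ops in float32; whether any particular gconv op is sensitive to it is not established here. `float_format` is a hypothetical helper.

```python
def float_format(exp_bits, mant_bits):
    """Range and precision of a binary float format with an implicit leading 1."""
    bias = 2 ** (exp_bits - 1) - 1
    return {
        "max": (2 - 2.0 ** -mant_bits) * 2.0 ** bias,  # largest finite value
        "min_normal": 2.0 ** (1 - bias),              # smallest normal value
        "eps": 2.0 ** -mant_bits,                     # gap between 1.0 and the next value
    }


# (exponent bits, stored mantissa bits); each format also has one sign bit
for name, e, m in [("fp32", 8, 23), ("fp16", 5, 10), ("bf16", 8, 7)]:
    print(name, float_format(e, m))
```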

In [8]:
class geometric_conv_block(nn.Module):
    def __init__(self, in_channels, 
                       mid_channels, 
                       out_channels, 
                       kernel_size = 5,
                       stride= 1):
        
        super().__init__()
        padding = (kernel_size - 1) // 2 - 1
        
        self.gconv1_internal = gnn.GSeparableConvSE3(in_channels, 
                                                     mid_channels, 
                                                     kernel_size = kernel_size,
                                                     padding = padding,
                                                     stride=1)
        self.normalization_1 = gnn.GBatchNorm3d(mid_channels)
        self.activation_1 = nn.ReLU(inplace=True)
        
        self.gconv2_internal = gnn.GSeparableConvSE3(mid_channels, 
                                                     out_channels, 
                                                     kernel_size = kernel_size,
                                                     padding = padding,
                                                     stride=stride)
        self.normalization_2 = gnn.GBatchNorm3d(out_channels)
        self.activation_2 = nn.ReLU(inplace=True)
        

    def forward(self, x, H):
        
        x, H = self.gconv1_internal(x, H)  
        x, H = self.normalization_1(x, H)         
        x = self.activation_1(x)
        
        x, H = self.gconv2_internal(x, H)
        x, H = self.normalization_2(x, H)
        x = self.activation_2(x)
        
        return x, H




class geometric_upconv_block(nn.Module):


    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size = 3,
                 scale_factor = (2, 2, 2),
                 align_corners = False):
        
        
        super().__init__()

        

        padding = (kernel_size - 1) // 2 - 1

        
        self.scale_factor = scale_factor
        self.align_corners = align_corners
        

        self.gconv_internal = gnn.GSeparableConvSE3(in_channels,
                                                    out_channels,
                                                    kernel_size = kernel_size,
                                                    padding = padding)
        
        self.normalization_internal = gnn.GBatchNorm3d(out_channels)
        
        self.relu_internal = nn.ReLU(inplace = True)
        



    def forward(self, x: torch.Tensor, H: torch.Tensor):
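        # x: (B, C, R, D, Hs, Ws) with R = group axis; the argument H holds the group elements
        B, C, R, D, Hs, Ws = x.shape

        # Merge (C, R) so F.interpolate upsamples only the spatial axes
        x_cr = x.flatten(1, 2)  # (B, C*R, D, Hs, Ws)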
        x_cr = F.interpolate(x_cr,
                             scale_factor = self.scale_factor,
                             mode = 'trilinear',
                             align_corners = self.align_corners)

        # Restore group dim R
        D2, H2, W2 = x_cr.shape[-3:]
        x_up = x_cr.contiguous().view(B, C, R, D2, H2, W2)

        # Equivariant conv + norm + act
        x_up, H = self.gconv_internal(x_up, H)
        x_up, H = self.normalization_internal(x_up, H)
        x_up = self.relu_internal(x_up)
        
        return x_up, H





class Geometric_UNet3D(pl.LightningModule):
    def __init__(self, 
                 learning_rate=0.0001,
                 input_channels = 2,
                 first_kernel_size = 5):
        
        super().__init__()
        self.save_hyperparameters()
        
        
        self.first_kernel_size = first_kernel_size
        self.learning_rate = learning_rate
        
        
        # initial lifting
        self.padding_first = (self.first_kernel_size - 1) // 2
        
        
        self.lifting_layer = gnn.GLiftingConvSE3(in_channels = input_channels,
                                                 out_channels = 16,
                                                 kernel_size = self.first_kernel_size,
                                                 padding = self.padding_first)
        
        # Define the 3D U-Net layers (encoder)
        self.encoder1 = geometric_conv_block(16, 16, 16, 
                                             kernel_size = self.first_kernel_size)
        self.maxpooler1 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder2 = geometric_conv_block(16, 16, 32, 
                                             kernel_size = 5)
        self.maxpooler2 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder3 = geometric_conv_block(32, 32, 32, 
                                             kernel_size = 5)
        self.maxpooler3 = GMaxSpatialPool3d_fixed(2)
        
        self.encoder4 = geometric_conv_block(32, 64, 64, 
                                             kernel_size = 5)
        
        self.upconv1 = geometric_upconv_block(64, 
                                              64,
                                              kernel_size = 5,
                                              scale_factor = 2,
                                              align_corners = False)

        
        # Define the 3D U-Net layers (decoder)
        self.decoder1 = geometric_conv_block(64 + 32, 32, 32, 
                                             kernel_size = 5)
        self.upconv2 = geometric_upconv_block(32, 
                                              32,
                                              kernel_size = 5,
                                              scale_factor = 2,
                                              align_corners = False)
            
            
        self.decoder2 = geometric_conv_block(32 + 32, 16, 16, 
                                             kernel_size = 5)
        self.upconv3 = geometric_upconv_block(16, 
                                              16,
                                              kernel_size = 5,
                                              scale_factor = 2,
                                              align_corners = False)
            
        self.decoder3 = geometric_conv_block(16+16, 16, 16, 
                                             kernel_size = 5)

        
        
        self.group_pool = gnn.GAvgGroupPool() 

        # segmentation classification
        self.final_conv = nn.Conv3d(16, 1, kernel_size = 1)

        # sigmoid output for segmentation
        self.final_conv_activation = nn.Sigmoid()
            
        
        # applies HE normal initialization
        init.kaiming_normal_(self.final_conv.weight, mode='fan_in', nonlinearity='relu')
        if self.final_conv.bias is not None:
            init.zeros_(self.final_conv.bias)
        
    



    def forward(self, x):

        # lifting layer
        x, H = self.lifting_layer(x)                 

        # Encoding path
        enc1, H = self.encoder1(x, H)               
        mpl1, H = self.maxpooler1(enc1, H)          
        
        enc2, H = self.encoder2(mpl1, H)             
        mpl2, H = self.maxpooler2(enc2, H)
        
        enc3, H = self.encoder3(mpl2, H)           
        mpl3, H = self.maxpooler3(enc3, H)
        
        enc4, H = self.encoder4(mpl3, H)           
        
        # Decoding path
        upc1, H = self.upconv1(enc4, H)              
        conc1 = torch.cat([enc3, upc1], dim=1)       
        dec1, H = self.decoder1(conc1, H)            
        
        upc2, H = self.upconv2(dec1, H)              
        conc2 = torch.cat([enc2, upc2], dim=1)       
        dec2, H = self.decoder2(conc2, H)            
        
        upc3, H = self.upconv3(dec2, H)            
        conc3 = torch.cat([enc1, upc3], dim=1)       
        dec3, H = self.decoder3(conc3, H)            

        # projection
        dec3_project = self.group_pool(dec3)          
        
        # segmentation prediction path   
        out = self.final_conv(dec3_project)           
        out = self.final_conv_activation(out)     
        
        return out




    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.contiguous()
        y = y.contiguous()
        
        preds = self(x)
        loss = self.compute_loss(preds, y)
        
        acc = self.accuracy(preds, y)
        
        self.log("train_loss", loss, prog_bar=True, on_epoch=False, on_step=True)
        self.log("train_acc", acc, prog_bar=True, on_epoch=False, on_step=True)
        return loss


    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.contiguous()
        y = y.contiguous()
        
        preds = self(x)
        loss = self.compute_loss(preds, y)
        
        acc = self.accuracy(preds, y)
        
        self.log("valid_loss", loss, prog_bar=True, on_epoch=True, on_step=True)
        self.log("valid_acc", acc, prog_bar=True, on_epoch=True, on_step=True)
        return loss



    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 
                                                         mode='min', 
                                                         patience=3, 
                                                         factor=0.2
                                                         )
        return {"optimizer": optimizer, 
                "lr_scheduler": {"scheduler": scheduler, "monitor": "valid_loss"}}




    def compute_loss(self, preds, targets):

        loss = IoU_Loss_torch()(preds, targets)

        return loss



    def accuracy(self, preds, targets):
        
        # Convert predictions to binary (0 or 1)
        preds = preds > 0.5
        targets = targets > 0.5
        
        # Flatten the tensors to compute accuracy
        preds = preds.view(preds.size(0), -1)
        targets = targets.view(targets.size(0), -1)

   
        # Calculate accuracy per image in the batch
        correct_per_image = (preds == targets).sum(dim=1).float()  # Correct predictions per image
        total_per_image = preds.size(1)  # Number of elements per image

        # Compute accuracy for each image in the batch
        accuracy_per_image = correct_per_image / total_per_image
        
        
        return accuracy_per_image.mean()
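
A bookkeeping sketch for the skip connections in Geometric_UNet3D: it checks that the declared in_channels of each decoder block match the concatenated channels, and that the skip and upsampled tensors have the same spatial size at each `torch.cat`. It assumes convolutions keep the spatial size, GMaxSpatialPool3d_fixed(2) floors n // 2 (the nn.MaxPool3d default), and each upconv exactly doubles the size; `unet_cat_checks` is a hypothetical helper.

```python
def unet_cat_checks(spatial=128):
    """Channel and spatial sizes at each torch.cat in Geometric_UNet3D.

    Assumptions: every geometric_conv_block keeps the spatial size, each
    GMaxSpatialPool3d_fixed(2) floors n // 2, and each upconv doubles the size.
    Rows are (decoder, declared_in, skip_ch + up_ch, skip_side, up_side).
    """
    # encoder outputs as (out_channels, side length)
    enc1, enc2 = (16, spatial), (32, spatial // 2)
    enc3, enc4 = (32, spatial // 4), (64, spatial // 8)
    # (decoder, declared in_channels, skip tensor, upconv out_channels, upconv input side)
    stages = [
        ("decoder1", 64 + 32, enc3, 64, enc4[1]),
        ("decoder2", 32 + 32, enc2, 32, enc3[1]),  # dec1 keeps enc3's side
        ("decoder3", 16 + 16, enc1, 16, enc2[1]),  # dec2 keeps enc2's side
    ]
    return [(name, declared, skip_ch + up_ch, skip_side, 2 * up_side)
            for name, declared, (skip_ch, skip_side), up_ch, up_side in stages]


for row in unet_cat_checks(128):
    print(row)
print(unet_cat_checks(100)[0])  # ('decoder1', 96, 96, 25, 24): sides differ
```

With a 128-voxel side everything lines up; a side length that is not divisible by 8 (e.g. 100) would make the first concatenation fail, since 25 voxels from enc3 meet 24 from upconv1.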