In [28]:
import time

import torch
import torch.nn as nn

from panther.nn import SKConv2d

# Target device for all benchmarks: prefer the GPU when one is present, otherwise
# fall back to CPU so the notebook still runs end-to-end on GPU-less machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class CNN(nn.Module):
    """One-layer convolutional model (40 -> 800 channels, 3x3 kernel) for benchmarking.

    Args:
        use_custom: if True, the layer is panther's ``SKConv2d`` configured with
            ``num_terms=2, low_rank=4``; otherwise it is a standard ``nn.Conv2d``.
            Both use ``kernel_size=3, stride=1, padding=1``.
    """

    def __init__(self, use_custom=False):
        super().__init__()
        if use_custom:
            layer = SKConv2d(
                40, 800, kernel_size=3, stride=1, padding=1, num_terms=2, low_rank=4
            )
        else:
            layer = nn.Conv2d(40, 800, kernel_size=3, stride=1, padding=1)
        self.conv1 = layer

    def forward(self, x):
        """Run ``x`` through the single convolution layer and return the result."""
        out = self.conv1(x)
        return out

In [29]:
def time_forward_pass(model, input_tensor, warmup=10, repeat=100, device=None):
    """Measure the average wall-clock time of one forward pass of ``model``.

    The model and a copy of the input are moved to ``device`` and the model is
    switched to eval mode. ``warmup`` untimed passes run first, then ``repeat``
    timed passes, all under ``torch.no_grad()``. On CUDA the device is
    synchronized before starting and before stopping the clock so that
    asynchronously queued kernels are attributed to the right interval.

    Args:
        model: module to benchmark; it is moved to ``device`` in place.
        input_tensor: input fed to ``model`` on every pass.
        warmup: number of untimed passes run before timing starts.
        repeat: number of timed passes; must be >= 1.
        device: target device (``torch.device`` or string such as ``"cpu"``).
            When omitted, the notebook-level ``device`` global is used.

    Returns:
        Average seconds per forward pass, as a float.

    Raises:
        ValueError: if ``repeat`` is less than 1.
    """
    if repeat < 1:
        raise ValueError(f"repeat must be >= 1, got {repeat}")
    if device is None:
        # Backward-compatible default: the module-level `device` from the config cell.
        device = globals()["device"]
    device = torch.device(device)
    is_cuda = device.type == "cuda"

    model.to(device)
    input_tensor = input_tensor.to(device)
    model.eval()  # no dropout/batchnorm updates

    with torch.no_grad():
        # Warm-up (especially important for CUDA)
        for _ in range(warmup):
            model(input_tensor)

        # Drain pending GPU work so it is not billed to the timed loop.
        if is_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()  # monotonic, high-resolution clock for benchmarks

        for _ in range(repeat):
            output = model(input_tensor)

        # CUDA kernels launch asynchronously; wait for them before stopping the clock.
        if is_cuda:
            torch.cuda.synchronize(device)
        elapsed = time.perf_counter() - start

    # Printed after the clock stops so console I/O is not part of the measurement.
    print("Output shape:", output.shape)

    return elapsed / repeat


# Compare custom vs PyTorch Conv2D
def compare_models(warmup=10, repeat=100):
    """Benchmark the SKConv2d-based CNN against the nn.Conv2d baseline.

    Both variants run on the same random input of shape (10, 40, 100, 100)
    placed on the notebook-level ``device``. For each variant, the average
    forward-pass time reported by ``time_forward_pass`` is printed.

    Args:
        warmup: untimed passes per model before timing.
        repeat: timed passes per model.
    """
    inputs = torch.randn(10, 40, 100, 100).to(device)
    # Inference-only timing (no training); the custom layer is measured first.
    for use_custom, label in ((True, "Custom"), (False, "PyTorch")):
        model = CNN(use_custom=use_custom)
        avg_time = time_forward_pass(model, inputs, warmup=warmup, repeat=repeat)
        print(f"{label} Conv2D forward pass time: {avg_time:.6f} seconds")

In [30]:
compare_models()
# NOTE(review): the two hand-written notes below have unclear units ("m:sec") and
# setup; they do not match the per-pass averages in the recorded output of this
# cell (~0.02 s for each model). TODO confirm what was measured, or remove them.
# 1:26 m:sec for normal conv2d
# 1:38 m:sec for custom conv2d

Output shape: torch.Size([10, 800, 100, 100])
Custom Conv2D forward pass time: 0.020260 seconds
Output shape: torch.Size([10, 800, 100, 100])
PyTorch Conv2D forward pass time: 0.020740 seconds
