In [9]:
import cProfile
import os
import pstats
import time

import psutil
import torch
import torch.nn as nn
import torch.optim as optim

from panther.nn import Performers


In [10]:
def profile(func, *args):
    with cProfile.Profile() as pr:
        rets = func(*args)
    results = pstats.Stats(pr).sort_stats(pstats.SortKey.TIME)
    results.print_stats()
    return rets

In [11]:
# test model with iscausal =True

mha = Performers(128, 8, 128, iscausal=True)
optimizer = optim.Adam(mha.parameters(), lr=0.01)
for _ in range(100):
    optimizer.zero_grad()
    x = torch.randn(128, 8, 128)
    result, _ = mha(x, x, x)
    # test backward
    result.sum().backward()
    optimizer.step()

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")


# Sample synthetic dataset (sequence classification)
def generate_data(num_samples=10, seq_len=3000, dim=32, num_classes=2):
    X = torch.randn(num_samples, seq_len, dim, device=device)  # Random embeddings
    y = torch.randint(0, num_classes, (num_samples,), device=device)  # Random labels
    return X, y


def get_cpu_memory_usage():
    process = psutil.Process(os.getpid())
    return process.memory_info().rss  # in bytes


# Define a simple Transformer-like model with custom Multihead Attention
class CustomTransformer(nn.Module):
    def __init__(self, embed_dim, num_heads, num_classes, num_rand_features=None):
        super().__init__()
        if num_rand_features is None:
            self.attention = nn.MultiheadAttention(
                embed_dim, num_heads, batch_first=True
            )
        else:
            self.attention = Performers(
                embed_dim,
                num_heads,
                num_random_features=num_rand_features,
                kernel_fn="softmax",
            )
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.attention(x, x.clone(), x.clone())[0]  # Multihead Attention
        x = x.mean(dim=1)  # Pooling
        return self.fc(x)


def train_model(model, X_train, y_train, epochs=5, lr=1e-3):
    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()
    start_time = time.time()

    if device.type == "cuda":
        torch.cuda.reset_peak_memory_stats(device)
    else:
        mem_before = get_cpu_memory_usage()

    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

    training_time = time.time() - start_time

    if device.type == "cuda":
        mem_usage = torch.cuda.max_memory_allocated(device)
    else:
        mem_usage = get_cpu_memory_usage() - mem_before

    return loss.item(), training_time, mem_usage


# Forward pass benchmarking
def measure_forward_pass(model, X):
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        if device.type == "cuda":
            torch.cuda.reset_peak_memory_stats(device)
            torch.cuda.synchronize()
        else:
            mem_before = get_cpu_memory_usage()

        start_time = time.time()
        _ = model(X)
        if device.type == "cuda":
            torch.cuda.synchronize()
        forward_time = time.time() - start_time

        if device.type == "cuda":
            mem_usage = torch.cuda.max_memory_allocated(device)
        else:
            mem_usage = get_cpu_memory_usage() - mem_before

    return forward_time, mem_usage


# Hyperparameters
embed_dim = 32
num_heads = 4
num_classes = 2


# Generate synthetic dataset
def fun(s):
    X_train, y_train = generate_data(seq_len=s)

    # Initialize models
    custom_attention_model = CustomTransformer(
        embed_dim, num_heads, num_classes, num_rand_features=128
    )
    torch_attention_model = CustomTransformer(embed_dim, num_heads, num_classes)
    # Compare models
    # torch_loss, torch_time, torch_mem = profile(
    #     train_model, torch_attention_model, X_train, y_train
    # )
    # print("---------------------------------------------------")
    # custom_loss, custom_time, custom_mem = profile(
    #     train_model, custom_attention_model, X_train, y_train
    # )
    # print("---------------------------------------------------")
    torch_loss, torch_time, torch_mem = train_model(
        torch_attention_model, X_train, y_train
    )
    custom_loss, custom_time, custom_mem = train_model(
        custom_attention_model, X_train, y_train
    )
    custom_forward_time, custom_forward_mem = measure_forward_pass(
        custom_attention_model, X_train
    )
    torch_forward_time, torch_forward_mem = measure_forward_pass(
        torch_attention_model, X_train
    )
    # Print results
    print(
        f"Custom Attention - Loss: {custom_loss:.4f}, Time: {custom_time:.4f}s, Memory: {custom_mem / 1024:.2f} KB,Memory per length: {custom_mem / 1024 / s:.2f} KB"
    )
    print(
        f"Torch Attention - Loss: {torch_loss:.4f}, Time: {torch_time:.4f}s, Memory: {torch_mem / 1024:.2f} KB,Memory per length: {torch_mem / 1024 / s:.2f} KB"
    )
    print(
        f"Custom Attention - Forward Time: {custom_forward_time:.4f}s, Forward Memory: {custom_forward_mem / 1024:.2f} KB,Memory per length: {custom_forward_mem / 1024 / s:.2f} KB"
    )
    print(
        f"Torch Attention - Forward Time: {torch_forward_time:.4f}s, Forward Memory: {torch_forward_mem / 1024:.2f} KB,Memory per length: {torch_forward_mem / 1024 / s:.2f} KB"
    )


In [13]:
for i in [10, 100, 1000, 3000, 4000]:
    print(f"the sequence length is {i}")
    fun(i)

the sequence length is 10
Custom Attention - Loss: 0.4705, Time: 0.1523s, Memory: 19102.00 KB,Memory per length: 1910.20 KB
Torch Attention - Loss: 0.6054, Time: 0.4451s, Memory: 17885.50 KB,Memory per length: 1788.55 KB
Custom Attention - Forward Time: 0.0010s, Forward Memory: 17842.50 KB,Memory per length: 1784.25 KB
Torch Attention - Forward Time: 0.0010s, Forward Memory: 17871.00 KB,Memory per length: 1787.10 KB
the sequence length is 100
Custom Attention - Loss: 0.4760, Time: 0.0255s, Memory: 32248.00 KB,Memory per length: 322.48 KB
Torch Attention - Loss: 0.6600, Time: 0.0285s, Memory: 24541.50 KB,Memory per length: 245.41 KB
Custom Attention - Forward Time: 0.0010s, Forward Memory: 25610.00 KB,Memory per length: 256.10 KB
Torch Attention - Forward Time: 0.0010s, Forward Memory: 21433.50 KB,Memory per length: 214.34 KB
the sequence length is 1000
Custom Attention - Loss: 0.5862, Time: 0.0280s, Memory: 170290.50 KB,Memory per length: 170.29 KB
Torch Attention - Loss: 0.6569, Time: