In [2]:
import torch
import torch.nn as nn

# Minimal shape check for torch's built-in multi-head attention.
embed_dim, num_heads = 128, 8
seq_length, batch_size = 10, 16

# batch_first=True: inputs below are laid out as (batch, sequence, embedding).
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

# Three independent uniform-random tensors, drawn in query/key/value order.
input_shape = (batch_size, seq_length, embed_dim)
q, k, v = (torch.rand(input_shape) for _ in range(3))

# The module returns the attended values plus the attention weights.
output, attn_output_weights = mha(q, k, v)
for result in (output, attn_output_weights):
    print(result.shape)

torch.Size([16, 10, 128])
torch.Size([16, 10, 10])


In [None]:
import os
import time

import psutil
import torch
import torch.nn as nn
import torch.optim as optim

from panther.nn import Performers

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
# To force a CPU run, replace the next two lines with: device = torch.device("cpu")
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)


# Sample synthetic dataset (sequence classification)
def generate_data(num_samples=10, seq_len=1000, dim=32, num_classes=2):
    """Create a random sequence-classification dataset on the global ``device``.

    Args:
        num_samples: Number of sequences to generate.
        seq_len: Length of each sequence.
        dim: Size of each per-step embedding vector.
        num_classes: Labels are drawn from ``0 .. num_classes - 1``.

    Returns:
        Tuple ``(X, y)``: ``X`` is a standard-normal tensor of shape
        ``(num_samples, seq_len, dim)`` and ``y`` is an integer label tensor
        of shape ``(num_samples,)``.
    """
    feature_shape = (num_samples, seq_len, dim)
    # Features are drawn before labels so the RNG stream is consumed in a fixed order.
    features = torch.randn(feature_shape, device=device)
    labels = torch.randint(
        low=0, high=num_classes, size=(num_samples,), device=device
    )
    return features, labels


def get_cpu_memory_usage():
    """Return the resident set size (RSS) of the current process, in bytes."""
    return psutil.Process(os.getpid()).memory_info().rss


# Define a simple Transformer-like model with custom Multihead Attention
class CustomTransformer(nn.Module):
    """Single self-attention layer, mean pooling over the sequence, linear classifier.

    The attention layer is chosen at construction time: when
    ``num_rand_features`` is ``None`` it is ``nn.MultiheadAttention`` with
    ``batch_first=True``; otherwise it is ``Performers`` (imported from
    ``panther.nn``) built with that many random features.
    """

    def __init__(self, embed_dim, num_heads, num_classes, num_rand_features=None):
        """Build the attention layer and the classification head.

        Args:
            embed_dim: Embedding size of the inputs and of the attention layer.
            num_heads: Number of attention heads.
            num_classes: Output size of the final linear layer.
            num_rand_features: ``None`` selects ``nn.MultiheadAttention``;
                any other value is forwarded to ``Performers`` as
                ``num_random_features``.
        """
        super().__init__()
        use_performer = num_rand_features is not None
        if use_performer:
            attention_layer = Performers(
                embed_dim, num_heads, num_random_features=num_rand_features
            )
        else:
            attention_layer = nn.MultiheadAttention(
                embed_dim, num_heads, batch_first=True
            )
        # The attention layer is registered before the head, which fixes both
        # the parameter-initialisation order and the state_dict key order.
        self.attention = attention_layer
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Self-attend over ``x``, average over dim 1, and return the logits."""
        # Key and value are separate copies of the query tensor.
        attention_result = self.attention(x, x.clone(), x.clone())
        # Element 0 of the attention layer's return value is the attended
        # sequence; it is averaged over dim 1 before classification.
        pooled = attention_result[0].mean(dim=1)
        return self.fc(pooled)


def train_model(model, X_train, y_train, epochs=5, lr=1e-3):
    """Train ``model`` full-batch with Adam and report loss, time and memory.

    Each epoch is one optimisation step over the whole of ``X_train``.

    Args:
        model: Module to train; it is moved to the module-level ``device``.
        X_train: Inputs passed to ``model`` as a single batch every epoch.
        y_train: Targets compared with the model outputs by
            ``nn.CrossEntropyLoss``.
        epochs: Number of optimisation steps. Must be at least 1.
        lr: Learning rate for Adam.

    Returns:
        Tuple ``(final_loss, training_time, mem_usage)``:
        ``final_loss`` is the loss of the last epoch as a Python float,
        ``training_time`` is the elapsed wall-clock time of the training
        loop in seconds, and ``mem_usage`` is in bytes — on CUDA the peak
        memory allocated by torch since the stats reset, on CPU the change
        in process RSS across the loop.

    Raises:
        ValueError: If ``epochs`` is less than 1 (there would be no loss to
            report).
    """
    if epochs < 1:
        raise ValueError(f"epochs must be >= 1, got {epochs}")

    model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    model.train()

    on_cuda = device.type == "cuda"
    if on_cuda:
        torch.cuda.reset_peak_memory_stats(device)
        # Drain any pending GPU work so it is not billed to this run.
        torch.cuda.synchronize(device)
    else:
        mem_before = get_cpu_memory_usage()

    # Start the clock only after the memory bookkeeping above so that the
    # measurement covers the training loop alone.
    start_time = time.time()

    for _ in range(epochs):
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()

    if on_cuda:
        # CUDA kernels are launched asynchronously; wait for them to finish
        # before stopping the clock, as measure_forward_pass does.
        torch.cuda.synchronize(device)
    training_time = time.time() - start_time

    if on_cuda:
        mem_usage = torch.cuda.max_memory_allocated(device)
    else:
        mem_usage = get_cpu_memory_usage() - mem_before

    return loss.item(), training_time, mem_usage


# Forward pass benchmarking
def measure_forward_pass(model, X):
    """Time one gradient-free forward pass of ``model`` on ``X`` and report memory.

    The model is moved to the module-level ``device`` and switched to eval
    mode; it is left in eval mode afterwards.

    Args:
        model: Module to benchmark.
        X: Input passed to ``model`` in a single call.

    Returns:
        Tuple ``(forward_time, mem_usage)``: elapsed wall-clock seconds for
        the forward call, and memory in bytes — on CUDA the peak allocated
        by torch since the stats reset, on CPU the change in process RSS.
    """
    model = model.to(device)
    model.eval()
    on_gpu = device.type == "cuda"

    with torch.no_grad():
        if on_gpu:
            torch.cuda.reset_peak_memory_stats(device)
            # Make sure no earlier GPU work is still running when timing starts.
            torch.cuda.synchronize()
        else:
            rss_baseline = get_cpu_memory_usage()

        tick = time.time()
        # Keep a reference to the result so it stays alive until the memory
        # reading below has been taken.
        _output = model(X)
        if on_gpu:
            # Wait for the asynchronous kernels before stopping the clock.
            torch.cuda.synchronize()
        forward_time = time.time() - tick

        mem_usage = (
            torch.cuda.max_memory_allocated(device)
            if on_gpu
            else get_cpu_memory_usage() - rss_baseline
        )

    return forward_time, mem_usage


# Model hyperparameters; embed_dim matches the default `dim` of generate_data.
embed_dim, num_heads, num_classes = 32, 4, 2

# Synthetic dataset, generated with the defaults of generate_data.
X_train, y_train = generate_data()

# The Performers-based model (128 random features) is built first, then the
# nn.MultiheadAttention baseline.
custom_attention_model = CustomTransformer(
    embed_dim, num_heads, num_classes, num_rand_features=128
)
torch_attention_model = CustomTransformer(embed_dim, num_heads, num_classes)

# Train both models twice, baseline first each time. Only the numbers from the
# second pass survive, and that pass continues from the weights left by the
# first one.
# NOTE(review): the first pass is presumably a warm-up — confirm.
for _pass in range(2):
    torch_loss, torch_time, torch_mem = train_model(
        torch_attention_model, X_train, y_train
    )
    custom_loss, custom_time, custom_mem = train_model(
        custom_attention_model, X_train, y_train
    )

# Inference benchmark: Performers-based model first, then the baseline.
custom_forward_time, custom_forward_mem = measure_forward_pass(
    custom_attention_model, X_train
)
torch_forward_time, torch_forward_mem = measure_forward_pass(
    torch_attention_model, X_train
)

# Report; memory figures are converted from bytes to KB.
print(
    f"Custom Attention - Loss: {custom_loss:.4f}, Time: {custom_time:.4f}s, Memory: {custom_mem / 1024:.2f} KB"
)
print(
    f"Torch Attention - Loss: {torch_loss:.4f}, Time: {torch_time:.4f}s, Memory: {torch_mem / 1024:.2f} KB"
)
print(
    f"Custom Attention - Forward Time: {custom_forward_time:.4f}s, Forward Memory: {custom_forward_mem / 1024:.2f} KB"
)
print(
    f"Torch Attention - Forward Time: {torch_forward_time:.4f}s, Forward Memory: {torch_forward_mem / 1024:.2f} KB"
)


Custom Attention - Loss: 0.5677, Time: 0.7361s, Memory: 24.00 KB
Torch Attention - Loss: 0.6900, Time: 0.7982s, Memory: 147808.00 KB
Custom Attention - Forward Time: 0.0560s, Forward Memory: 0.00 KB
Torch Attention - Forward Time: 0.0700s, Forward Memory: 0.00 KB
