# Quantum-Inspired Sparse Attention - Benchmark Notebook
This notebook compares classical and quantum-inspired attention mechanisms on a toy classification task.

## Import Libraries
Load required PyTorch modules and utilities for model building and benchmarking.

In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
import random
import numpy as np

## Classical Multi-Head Attention
Defines the standard multi-head attention block used in Transformer models.

In [None]:
# Classical Multi-Head Attention Block
class MultiHeadAttentionBlock(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are each projected by a learned linear layer,
    split into ``num_heads`` heads of size ``d_k = d_model // num_heads``,
    attended independently per head, concatenated again and passed through
    an output projection.

    Args:
        d_model: Size of the last (feature) dimension of the inputs.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, d_model: int, num_heads: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        assert d_model % num_heads == 0, "d_model is not divisible by num_heads"

        self.d_k = d_model // num_heads  # Dimension of each head
        self.w_q = nn.Linear(d_model, d_model)  # Linear layer for queries
        self.w_k = nn.Linear(d_model, d_model)  # Linear layer for keys
        self.w_v = nn.Linear(d_model, d_model)  # Linear layer for values

        self.w_o = nn.Linear(d_model, d_model)  # Linear layer for output
        self.dropout = nn.Dropout(dropout)  # Dropout layer

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        """Compute scaled dot-product attention.

        Args:
            query, key, value: Tensors whose last two dimensions are
                (tokens, head_dim).
            mask: Optional tensor broadcastable to the score matrix;
                positions where ``mask == 0`` receive no attention.
            dropout: Optional dropout module applied to the weights.

        Returns:
            A tuple ``(output, weights)``: the weighted sum of ``value`` and
            the attention weights (useful for visualization).
        """
        d_k = query.shape[-1]

        # Scaled dot-product scores
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Fill with the most negative *finite* value instead of -inf:
            # a row whose keys are all masked would otherwise softmax to NaN
            # (exp(-inf - -inf)), and that NaN would propagate to the output.
            # Partially masked rows still get exactly zero weight on masked
            # positions because exp() underflows to 0.
            fill_value = torch.finfo(attention_scores.dtype).min
            attention_scores = attention_scores.masked_fill(mask == 0, fill_value)
        attention_weights = attention_scores.softmax(dim=-1)
        if dropout is not None:
            attention_weights = dropout(attention_weights)

        return (attention_weights @ value), attention_weights

    def forward(self, q, k, v, mask):
        """Apply multi-head attention.

        Args:
            q, k, v: Tensors of shape (batch, tokens, d_model).
            mask: Optional mask forwarded to :meth:`attention`, or ``None``.

        Returns:
            Tensor of shape (batch, query_tokens, d_model).
        """
        query = self.w_q(q)  # Linear transformation for queries
        key = self.w_k(k)  # Linear transformation for keys
        value = self.w_v(v)  # Linear transformation for values

        # (batch, tokens, d_model) -> (batch, heads, tokens, d_k)
        query = query.view(query.shape[0], query.shape[1], self.num_heads, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.num_heads, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention mechanism; the weights are not needed here
        x, _ = self.attention(query, key, value, mask, self.dropout)

        # (batch, heads, tokens, d_k) -> (batch, tokens, heads * d_k)
        batch_size, _, num_tokens, _ = x.shape
        x = x.transpose(1, 2).contiguous().view(batch_size, num_tokens, self.num_heads * self.d_k)

        # Apply the final linear transformation to the output
        return self.w_o(x)

## Create Simple Sample Dataset
Generates 8 random samples with 4 tokens each and binary labels for classification.

In [None]:
# Toy dataset: a batch of 8 sequences, each holding 4 tokens of 8-dim embeddings
X = torch.rand((8, 4, 8))  # shape: (batch_size, seq_len, d_model)
y = torch.randint(low=0, high=2, size=(8,))  # one 0/1 class label per sample

## Classical Attention Model
Defines a classifier using the classical multi-head attention mechanism.

In [None]:
# Simple classifier using classical attention
class ClassicalClassifier(nn.Module):
    """Sequence classifier built on classical multi-head self-attention.

    The input sequence attends to itself, the resulting token features are
    averaged over the sequence dimension, and a linear layer maps the pooled
    vector to class logits.

    Args:
        d_model: Embedding size of each input token.
        num_heads: Number of attention heads; must divide ``d_model``.
        dropout: Dropout probability used inside the attention block.
        num_classes: Number of output logits.
    """

    def __init__(self, d_model: int = 8, num_heads: int = 2,
                 dropout: float = 0.1, num_classes: int = 2):
        super().__init__()
        self.attn = MultiHeadAttentionBlock(d_model=d_model, num_heads=num_heads, dropout=dropout)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for ``x`` of shape (batch, tokens, d_model)."""
        # Self-attention: the same tensor is query, key and value; no mask
        out = self.attn(x, x, x, None)
        out = out.mean(dim=1)  # Global average pooling over sequence
        return self.fc(out)

## Quantum-Inspired Attention Function
Simulates quantum-like behavior by replacing the dot-product scores with random scores that are passed through a softmax (no vendor-provided quantum simulator is used).

In [None]:
# Simulated quantum-inspired attention (randomized weights)
def quantum_inspired_attention(query, key, value):
    """Mix ``value`` tokens using randomly drawn attention weights.

    Uniform random scores are sampled for every (query token, value token)
    pair and normalized with a softmax, so each output token is a random
    convex combination of the value tokens. The contents of ``query`` and
    ``key`` are not used; ``query`` only supplies the output layout.

    Args:
        query: Tensor of shape (batch, heads, query_tokens, head_dim).
        key: Accepted for signature parity with dot-product attention; unused.
        value: Tensor of shape (batch, heads, value_tokens, head_dim).

    Returns:
        Tensor of shape (batch, heads, query_tokens, head_dim).
    """
    batch_size, num_heads, num_queries, _ = query.shape
    num_values = value.shape[-2]
    # Draw the scores with value's dtype and device: torch.rand otherwise
    # yields a default-dtype tensor on the default device, and the matmul
    # below fails for float64 or GPU inputs.
    scores = torch.rand(
        batch_size, num_heads, num_queries, num_values,
        dtype=value.dtype, device=value.device,
    )
    weights = torch.softmax(scores, dim=-1)
    return weights @ value

## Quantum-Inspired Model
Builds a classifier that uses the quantum-inspired attention for feature extraction.

In [None]:
# Quantum-inspired classifier
class QuantumInspiredClassifier(nn.Module):
    """Sequence classifier that uses the quantum-inspired (random) attention.

    Tokens are projected to queries, keys and values, split into heads,
    mixed by ``quantum_inspired_attention``, merged back, averaged over the
    sequence dimension and mapped to class logits.

    Args:
        d_model: Embedding size of each input token.
        num_heads: Number of attention heads; must divide ``d_model``.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model: int = 8, num_heads: int = 2, num_classes: int = 2):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError("d_model is not divisible by num_heads")
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # Dimension of each head
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.fc = nn.Linear(d_model, num_classes)

    def _split_heads(self, t):
        """Reshape (batch, tokens, d_model) into (batch, heads, tokens, d_k)."""
        # The token axis must stay in position 1 for the view; only then does
        # the transpose yield one contiguous d_k-sized feature slice per head.
        return t.view(t.size(0), t.size(1), self.num_heads, self.d_k).transpose(1, 2)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for ``x`` of shape (batch, tokens, d_model)."""
        q = self._split_heads(self.w_q(x))
        k = self._split_heads(self.w_k(x))
        v = self._split_heads(self.w_v(x))
        out = quantum_inspired_attention(q, k, v)
        # (batch, heads, tokens, d_k) -> (batch, tokens, heads * d_k)
        out = out.transpose(1, 2).contiguous().view(x.size(0), x.size(1), self.num_heads * self.d_k)
        out = out.mean(dim=1)  # Global average pooling over sequence
        return self.fc(out)

## Evaluation Function
Defines a function to compute accuracy and runtime of a given model on the dataset.

In [None]:
# Evaluation function
def evaluate(model, X, y):
    """Measure classification accuracy and forward-pass time of ``model``.

    The model is switched to evaluation mode (and left in it), run once on
    the whole of ``X`` without gradient tracking, and its argmax predictions
    along dimension 1 are compared with ``y``.

    Args:
        model: Module mapping ``X`` to logits of shape (samples, classes).
        X: Input batch passed to the model in a single call.
        y: Ground-truth class indices, one per sample.

    Returns:
        A tuple ``(accuracy, elapsed)``: the fraction of correct predictions
        as a Python float, and the elapsed time in seconds.
    """
    model.eval()
    # perf_counter is a monotonic, high-resolution clock intended for timing
    # short intervals; time.time() can jump with system clock adjustments.
    start = time.perf_counter()
    with torch.no_grad():
        logits = model(X)
        preds = torch.argmax(logits, dim=1)
        accuracy = (preds == y).float().mean().item()
    elapsed = time.perf_counter() - start
    return accuracy, elapsed

## Run and Compare Models
Evaluates both models and prints out their accuracy and inference time.

In [None]:
# Build one untrained instance of each classifier (classical first, then quantum-inspired)
classical_model, quantum_model = ClassicalClassifier(), QuantumInspiredClassifier()

# Score both models on the shared toy dataset; evaluate() returns (accuracy, elapsed seconds)
(acc_classical, time_classical), (acc_quantum, time_quantum) = (
    evaluate(classical_model, X, y),
    evaluate(quantum_model, X, y),
)

# Report accuracy as a percentage next to the elapsed inference time
print(f"Classical Accuracy: {acc_classical*100:.1f}% | Time: {time_classical:.4f}s")
print(f"Quantum-Inspired Accuracy: {acc_quantum*100:.1f}% | Time: {time_quantum:.4f}s")

Classical Accuracy: 37.5% | Time: 0.0025s
Quantum-Inspired Accuracy: 37.5% | Time: 0.0021s
