# SparseFlash2LinformerAttention

In [139]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam



def random_projection(matrix, k):
    """Random projection to reduce dimensionality of matrix to k dimensions."""
    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
    return torch.matmul(matrix, random_matrix)

def cur_decomposition(matrix, projection_dim):  # Change 'k' argument
    """Applies CUR decomposition with C matrix dimension aligned to projection dimension."""
    batch_size, seq_length, heads, dim = matrix.shape
    k = min(projection_dim // 2, dim)  # Use projection dimension to determine 'k'
    C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
    R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)

    for b in range(batch_size):
        for h in range(heads):
            col_indices = np.random.choice(dim, k, replace=False)
            row_indices = np.random.choice(seq_length, k, replace=False)
            C[b, :, h] = matrix[b, :, h, col_indices]
            R[b, :, h] = matrix[b, row_indices, h]
    return C, R

class PartitionedLinformerAttentionACT(nn.Module):
    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.value_projection = nn.Linear(self.head_dim, self.projection_dim//2)  # To project values to match dimensions
        self.ponder = nn.Linear(self.partition_size, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        print(f"BEGIN FORWARD")
        print(f"x input shape: {x.shape}")
        N, seq_length, _ = x.shape
        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped shape: {x_reshaped.shape}")

        values = self.values(x_reshaped)
        queries = random_projection(self.queries(x_reshaped), self.projection_dim // 2)
        keys = random_projection(self.keys(x_reshaped), self.projection_dim // 2 )
        print("values :", values.shape)
        print("queries :", queries.shape)
        print("keys :", keys.shape)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)
        print(f"attention_scores after random projection: {attention_scores.shape}")

        for i in range(0, seq_length, self.partition_size):
            print(f"PARTITION START")
            partition_start = i
            partition_end = min(i + self.partition_size, seq_length)
            keys_part = keys[:, partition_start:partition_end, :, :]
            queries_part = queries[:, partition_start:partition_end, :, :]

            C_keys, R_queries = cur_decomposition(keys_part, self.projection_dim)
            print("C_keys shape before return:", C_keys.shape)

            ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
            print(f"Partition Start {i}, Partition End {partition_end} , ponder_scores: {ponder_scores.shape}")

            for h in range(self.heads):
                #print(f"HEADS START")
                head_queries = queries_part[:, :, h, :]
                #print(f"head_queries: {head_queries.shape}")
                head_ponder_scores = self.sigmoid(self.ponder(head_queries))
                #print(f"head_ponder_scores: {head_ponder_scores.shape}")
                ponder_scores[:, h, :, 0] = head_ponder_scores.squeeze(-1)

            # Correctly expand ponder_scores without adding an unnecessary dimension
            print("BEFORE 1ST EINSUM:")
            ponder_scores_permuted = ponder_scores.permute(0, 2, 1, 3)  # Move to [2, 128, 8, 1]
            print("ponder_scores_permuted shape:", ponder_scores_permuted.shape) 
            ponder_scores_broadcastable = ponder_scores_permuted.expand(-1, -1, -1, 128)  # Expand to [2, 128, 8, 128]            print(f"ponder_scores_broadcastable shape: {ponder_scores_broadcastable.shape}")
            print("ponder_scores_broadcastable shape:", ponder_scores_broadcastable.shape) 
            print("queries_part shape:", queries_part.shape) 
            print("C_keys shape:", C_keys.shape)
            energy = torch.einsum('bnhd,bnhk->bnhd', queries_part, C_keys)
            attention_weights = F.softmax(energy, dim=-1)
            print("AFTER 1ST EINSUM:")
            print("energy shape:", energy.shape) 
            print("attention_weights shape:", attention_weights.shape)
            attention = attention_weights * ponder_scores_broadcastable
            print("attention shape:", attention.shape)
            attention_corrected = attention.permute(0, 2, 1, 3)
            attention_scores[:, :, partition_start:partition_end, :] = attention_corrected

        values = values.permute(0, 2, 1, 3)  # Swap heads and seq_length to bring heads next to head_dim
        print("values shape:", values.shape)
        values = values.reshape(-1, self.head_dim)  # Flatten to [N*heads*seq_length, head_dim] for linear layer
        print("values.reshape(-1, self.head_dim) shape:", values.shape)
        projected_values = self.value_projection(values)  # Now [N*heads*seq_length, projection_dim / 2]
        print("self.value_projection(values) shape:", projected_values.shape)
        projected_values = projected_values.view(N, self.heads, seq_length, self.projection_dim // 2)
        print("projected_values shape:", projected_values.shape)

        print(f"2ND EINSUM")
        # Combine attention_scores and projected_values then pass through the final linear layer
        out = torch.einsum('bnhp,bnhp->bnh', attention_scores, projected_values)
        print("out shape:", out.shape)
       
        return out


# Example usage
model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)
input_tensor = torch.rand(2, model.sequence_length, model.embed_size)
output = model(input_tensor)
print(output.shape)  


# Create a DataLoader with pinned memory
class MyDataset(Dataset):

    def __init__(self):
        # Initialize your data here
        pass
    def __len__(self):
        # Return the size of your dataset
        return 100  # Example size
    def __getitem__(self, idx):
        # Implement logic to get a single item at a given index
        # For simplicity, let's return random tensors
        return torch.randn(10, 64), torch.randint(0, 2, (10,))  # Example data

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for epoch in range(10):  # Example: 10 epochs
    for inputs, targets in loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast():
            predictions = model(inputs)
            loss = nn.functional.cross_entropy(predictions, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

BEGIN FORWARD
x input shape: torch.Size([2, 1024, 512])
x_reshaped shape: torch.Size([2, 1024, 8, 64])
values : torch.Size([2, 1024, 8, 64])
queries : torch.Size([2, 1024, 8, 128])
keys : torch.Size([2, 1024, 8, 128])
attention_scores after random projection: torch.Size([2, 8, 1024, 128])
PARTITION START
C_keys shape before return: torch.Size([2, 128, 8, 128])
Partition Start 0, Partition End 128 , ponder_scores: torch.Size([2, 8, 128, 1])
BEFORE 1ST EINSUM:
ponder_scores_permuted shape: torch.Size([2, 128, 8, 1])
ponder_scores_broadcastable shape: torch.Size([2, 128, 8, 128])
queries_part shape: torch.Size([2, 128, 8, 128])
C_keys shape: torch.Size([2, 128, 8, 128])
AFTER 1ST EINSUM:
energy shape: torch.Size([2, 128, 8, 128])
attention_weights shape: torch.Size([2, 128, 8, 128])
attention shape: torch.Size([2, 128, 8, 128])
PARTITION START
C_keys shape before return: torch.Size([2, 128, 8, 128])
Partition Start 128, Partition End 256 , ponder_scores: torch.Size([2, 8, 128, 1])
BEFORE 

RuntimeError: shape '[32, 10, 8, 64]' is invalid for input of size 20480