# v4

In [73]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam



def random_projection(matrix, k):
    """Random projection to reduce dimensionality of matrix to k dimensions."""
    random_matrix = torch.randn(matrix.size(-1), k, device=matrix.device)
    return torch.matmul(matrix, random_matrix)

def cur_decomposition(matrix, k):
    """Applies CUR decomposition on each head for each batch."""
    batch_size, seq_length, heads, dim = matrix.shape
    # Ensuring k does not exceed dimensions
    k = min(k, dim)
    C = torch.zeros(batch_size, seq_length, heads, k, device=matrix.device)
    R = torch.zeros(batch_size, k, heads, dim, device=matrix.device)
    
    for b in range(batch_size):
        for h in range(heads):
            col_indices = np.random.choice(dim, k, replace=False)
            row_indices = np.random.choice(seq_length, k, replace=False)
            C[b, :, h] = matrix[b, :, h, col_indices]
            R[b, :, h] = matrix[b, row_indices, h]
    return C, R

class PartitionedLinformerAttentionACT(nn.Module):
    def __init__(self, embed_size, heads, sequence_length, projection_dim, partition_size):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.sequence_length = sequence_length
        self.projection_dim = projection_dim
        self.partition_size = partition_size
        self.head_dim = embed_size // heads

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.projection_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)
        self.ponder = nn.Linear(self.head_dim, 1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        N, seq_length, _ = x.shape
        print(f"x shape as input to forward: {x.shape}")
        print(f"embed size as input to forward: {self.embed_size}")

        x_reshaped = x.view(N, seq_length, self.heads, self.head_dim)
        print(f"x_reshaped : {x_reshaped.shape}")

        values = self.values(x_reshaped).view(N, seq_length, self.heads, self.head_dim)
        keys = self.keys(x_reshaped).view(N, seq_length, self.heads, self.projection_dim)
        queries = self.queries(x_reshaped).view(N, seq_length, self.heads, self.projection_dim)

        keys = random_projection(keys, self.projection_dim // 2)
        queries = random_projection(queries, self.projection_dim // 2)

        attention_scores = torch.zeros(N, self.heads, seq_length, self.projection_dim // 2, device=x.device)

        for i in range(0, seq_length, self.partition_size):
            partition_start = i
            partition_end = min(i + self.partition_size, seq_length)
            keys_partition = keys[:, partition_start:partition_end, :, :]
            queries_partition = queries[:, partition_start:partition_end, :, :]

            C_keys, R_queries = cur_decomposition(keys_partition, k=keys_partition.size(-1) // 2)

            ponder_scores = torch.zeros(N, self.heads, partition_end - partition_start, 1, device=x.device)
            for h in range(self.heads):
                # Flatten the input tensor for each head across batch and sequence dimensions
                # while keeping the head dimension intact for linear layer processing
                head_queries = queries_partition[:, :, h, :].reshape(-1, self.head_dim)

                # Apply the ponder network
                head_ponder_scores = self.sigmoid(self.ponder(head_queries)).view(N, partition_end - partition_start, 1)

                # Reshape head_ponder_scores back to match the original batch and partition sizes
                # and expand the last dimension to match the expected shape for broadcasting
                head_ponder_scores_reshaped = head_ponder_scores.view(N, partition_end - partition_start, -1).expand(-1, -1, self.head_dim)

                # Assign the reshaped ponder scores back to the corresponding head in the ponder_scores tensor
                ponder_scores[:, :, h, :] = head_ponder_scores_reshaped.squeeze(-1)

            energy = torch.einsum('bnhdp,bnhpd->bnhdp', queries_partition, C_keys.transpose(-2, -1))
            attention = F.softmax(energy, dim=-1) * ponder_scores
            attention_scores[:, :, partition_start:partition_end, :] = attention

        out = torch.einsum('bnhdp,bnhpd->bnshp', attention_scores, values).reshape(N, seq_length, -1)
        return self.fc_out(out)




# Example usage

model = PartitionedLinformerAttentionACT(embed_size=512, heads=8, sequence_length=1024, projection_dim=256, partition_size=128)
input_tensor = torch.rand(2, model.sequence_length, model.embed_size)
output = model(input_tensor)
print(output.shape)  # Expected output shape: [2, sequence_length, embed_size]


# Create a DataLoader with pinned memory
class MyDataset(Dataset):

    def __init__(self):

        # Initialize your data here

        pass

    

    def __len__(self):

        # Return the size of your dataset

        return 100  # Example size



    def __getitem__(self, idx):

        # Implement logic to get a single item at a given index

        # For simplicity, let's return random tensors

        return torch.randn(10, 64), torch.randint(0, 2, (10,))  # Example data

dataset = MyDataset()
loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
optimizer = Adam(model.parameters(), lr=1e-3)
scaler = GradScaler()

for epoch in range(10):  # Example: 10 epochs
    for inputs, targets in loader:
        inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast():
            predictions = model(inputs)
            loss = nn.functional.cross_entropy(predictions, targets)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()



x shape as input to forward: torch.Size([2, 1024, 512])
embed size as input to forward: 512
x_reshaped : torch.Size([2, 1024, 8, 64])


RuntimeError: shape '[2, 128, 1]' is invalid for input of size 512