In [1]:
import torch
import torch.nn as nn
import numpy as np

# Define the AttentionRNNCell (from your implementation)
class AttentionRNNCell(nn.Module):
    def __init__(self, d_model):
        super(AttentionRNNCell, self).__init__()
        self.d_model = d_model  # Dimensionality of keys/queries/values

    def forward_iterative(self, q, k, v, prev_a=None, prev_c=None, prev_m=None):
        batch_size = q.size(0)

        # Initialize previous states if not provided
        if prev_a is None:
            prev_a = torch.zeros(batch_size, self.d_model).to(k.device)
        if prev_c is None:
            prev_c = torch.zeros(batch_size, 1).to(k.device)
        if prev_m is None:
            prev_m = torch.full((batch_size, 1), 0).to(k.device)

        # Compute scores s_k = q . k^T
        s_k = torch.sum(q * k, dim=-1, keepdim=True)  # Shape: (batch_size, 1)

        # Update m_k (max cumulative score)
        m_k = torch.max(s_k, prev_m)

        # Compute exp terms for stability
        exp_term1 = torch.exp(prev_m - m_k)  # Shape: (batch_size, 1)
        exp_term2 = torch.exp(s_k - m_k)    # Shape: (batch_size, 1)

        # Update a_k and c_k
        a_k = prev_a * exp_term1 + v * exp_term2  # Shape: (batch_size, d_model)
        c_k = prev_c * exp_term1 + exp_term2      # Shape: (batch_size, 1)

        return a_k, c_k, m_k

# Configuration
batch_size = 1  # Number of audio sequences to process at once
input_dim = 10  # Number of audio frames in the sequence
d_model = 40    # Dimensionality of each audio frame (e.g., MFCCs)

# Simulate audio frames (e.g., MFCCs or spectrogram features)
# Each frame is a vector of size d_model
audio_frames = torch.randn(input_dim, batch_size, d_model)  # Shape: (input_dim, batch_size, d_model)

# Initialize a fixed query vector (learned during training)
query = torch.randn(batch_size, d_model)  # Shape: (batch_size, d_model)

# Initialize AttentionRNNCell
cell = AttentionRNNCell(d_model)

# Initialize hidden states
a_k, c_k, m_k = None, None, None

# Process audio frames iteratively (simulating a streaming scenario)
for i in range(input_dim):
    # Get the current frame (key and value)
    frame = audio_frames[i, :, :]  # Shape: (batch_size, d_model)

    # Perform a step of attention
    a_k, c_k, m_k = cell.forward_iterative(query, frame, frame, a_k, c_k, m_k)

    # Print intermediate results (optional)
    print(f"Frame {i+1}:")
    print("a_k (weighted sum of values):\n", a_k)
    print("c_k (normalization term):\n", c_k)
    print("m_k (max score):\n", m_k)
    print("-" * 50)

# Final attention output
attention_output = a_k / c_k  # Shape: (batch_size, d_model)
print("Final Attention Output:\n", attention_output)

Frame 1:
a_k (weighted sum of values):
 tensor([[ 0.0208, -0.0104,  0.0255, -0.0144, -0.0352, -0.0093, -0.1045, -0.0262,
          0.0012,  0.0876, -0.0329, -0.0573,  0.0362, -0.0485,  0.0376, -0.0534,
         -0.0725,  0.0190, -0.0151, -0.0662, -0.1108,  0.0521,  0.0771,  0.1231,
          0.0989,  0.0649, -0.0252,  0.0742, -0.0007, -0.0008,  0.0393,  0.0540,
         -0.0276, -0.0059, -0.0349,  0.0800, -0.0888,  0.0413,  0.0895, -0.0748]])
c_k (normalization term):
 tensor([[0.0752]])
m_k (max score):
 tensor([[0.]])
--------------------------------------------------
Frame 2:
a_k (weighted sum of values):
 tensor([[ 0.0175, -0.0087,  0.0243, -0.0116, -0.0328, -0.0033, -0.0994, -0.0298,
          0.0064,  0.0942, -0.0316, -0.0524,  0.0383, -0.0468,  0.0425, -0.0486,
         -0.0831,  0.0199, -0.0154, -0.0665, -0.1138,  0.0535,  0.0723,  0.1237,
          0.1064,  0.0601, -0.0248,  0.0722,  0.0067,  0.0020,  0.0354,  0.0522,
         -0.0248, -0.0048, -0.0353,  0.0746, -0.0931,  0.04

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define the AttentionRNNCell (from your implementation)
class AttentionRNNCell(nn.Module):
    def __init__(self, d_model):
        super(AttentionRNNCell, self).__init__()
        self.d_model = d_model  # Dimensionality of keys/queries/values

    def forward_iterative(self, q, k, v, prev_a=None, prev_c=None, prev_m=None):
        batch_size = q.size(0)

        # Initialize previous states if not provided
        if prev_a is None:
            prev_a = torch.zeros(batch_size, self.d_model).to(k.device)
        if prev_c is None:
            prev_c = torch.zeros(batch_size, 1).to(k.device)
        if prev_m is None:
            prev_m = torch.full((batch_size, 1), 0).to(k.device)

        # Compute scores s_k = q . k^T
        s_k = torch.sum(q * k, dim=-1, keepdim=True)  # Shape: (batch_size, 1)

        # Update m_k (max cumulative score)
        m_k = torch.max(s_k, prev_m)

        # Compute exp terms for stability
        exp_term1 = torch.exp(prev_m - m_k)  # Shape: (batch_size, 1)
        exp_term2 = torch.exp(s_k - m_k)    # Shape: (batch_size, 1)

        # Update a_k and c_k
        a_k = prev_a * exp_term1 + v * exp_term2  # Shape: (batch_size, d_model)
        c_k = prev_c * exp_term1 + exp_term2      # Shape: (batch_size, 1)

        return a_k, c_k, m_k

# Define the full model
class AttentionRNNClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(AttentionRNNClassifier, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Attention RNN Cell
        self.attention_cell = AttentionRNNCell(d_model)

        # Classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Initialize query (learnable parameter)
        query = nn.Parameter(torch.randn(batch_size, d_model)).to(x.device)

        # Initialize hidden states
        a_k, c_k, m_k = None, None, None

        # Process the sequence iteratively
        for i in range(seq_len):
            frame = x[:, i, :]  # Get the i-th frame in the sequence
            a_k, c_k, m_k = self.attention_cell.forward_iterative(query, frame, frame, a_k, c_k, m_k)

        # Final attention output
        attention_output = a_k / c_k  # Shape: (batch_size, d_model)

        # Pass through the classification head
        logits = self.fc(attention_output)  # Shape: (batch_size, num_classes)

        return logits

# Simulate a dataset
class AudioDataset(Dataset):
    def __init__(self, num_samples, seq_len, d_model, num_classes):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.d_model = d_model
        self.num_classes = num_classes

        # Simulate random audio sequences and labels
        self.data = torch.randn(num_samples, seq_len, d_model)  # Random audio frames
        self.labels = torch.randint(0, num_classes, (num_samples,))  # Random labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Configuration
batch_size = 32
seq_len = 10  # Number of audio frames per sequence
d_model = 40  # Dimensionality of each audio frame (e.g., MFCCs)
num_classes = 5  # Number of classes for classification
num_samples = 1000  # Number of samples in the dataset
num_epochs = 10
learning_rate = 0.001

# Create dataset and dataloader
dataset = AudioDataset(num_samples, seq_len, d_model, num_classes)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, and optimizer
model = AttentionRNNClassifier(d_model, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, labels) in enumerate(dataloader):
        # Forward pass
        logits = model(data)
        loss = criterion(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Print statistics
        running_loss += loss.item()
        if batch_idx % 10 == 9:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], "
                  f"Loss: {running_loss / 10:.4f}, Accuracy: {100 * correct / total:.2f}%")
            running_loss = 0.0
            correct = 0
            total = 0

# Test the model
model.eval()
test_dataset = AudioDataset(num_samples=100, seq_len=seq_len, d_model=d_model, num_classes=num_classes)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_dataloader:
        logits = model(data)
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Epoch [1/10], Batch [10/32], Loss: 1.7103, Accuracy: 19.69%
Epoch [1/10], Batch [20/32], Loss: 1.7789, Accuracy: 16.88%
Epoch [1/10], Batch [30/32], Loss: 1.7023, Accuracy: 21.56%
Epoch [2/10], Batch [10/32], Loss: 1.6910, Accuracy: 19.06%
Epoch [2/10], Batch [20/32], Loss: 1.6731, Accuracy: 25.00%
Epoch [2/10], Batch [30/32], Loss: 1.6674, Accuracy: 20.00%
Epoch [3/10], Batch [10/32], Loss: 1.7264, Accuracy: 16.25%
Epoch [3/10], Batch [20/32], Loss: 1.6854, Accuracy: 19.06%
Epoch [3/10], Batch [30/32], Loss: 1.6568, Accuracy: 21.88%
Epoch [4/10], Batch [10/32], Loss: 1.6798, Accuracy: 20.00%
Epoch [4/10], Batch [20/32], Loss: 1.6567, Accuracy: 20.31%
Epoch [4/10], Batch [30/32], Loss: 1.6862, Accuracy: 21.56%
Epoch [5/10], Batch [10/32], Loss: 1.6710, Accuracy: 23.12%
Epoch [5/10], Batch [20/32], Loss: 1.6683, Accuracy: 20.94%
Epoch [5/10], Batch [30/32], Loss: 1.6650, Accuracy: 20.00%
Epoch [6/10], Batch [10/32], Loss: 1.6069, Accuracy: 29.38%
Epoch [6/10], Batch [20/32], Loss: 1.635

In [3]:
from torchaudio.datasets import LIBRISPEECH

class LibriSpeechDataset(Dataset):
    def __init__(self, root, url):
        self.dataset = LIBRISPEECH(root=root, url=url)
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, idx):
        waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id = self.dataset[idx]
        # Extract MFCCs or other features
        mfcc = torchaudio.transforms.MFCC()(waveform)
        return mfcc, speaker_id  # Use speaker_id as the label

In [4]:
class AttentionRNNClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(AttentionRNNClassifier, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Learnable query
        self.query = nn.Parameter(torch.randn(1, d_model))

        # Attention RNN Cell
        self.attention_cell = AttentionRNNCell(d_model)

        # Classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Expand query to match batch size
        query = self.query.expand(batch_size, -1)

        # Initialize hidden states
        a_k, c_k, m_k = None, None, None

        # Process the sequence iteratively
        for i in range(seq_len):
            frame = x[:, i, :]  # Get the i-th frame in the sequence
            a_k, c_k, m_k = self.attention_cell.forward_iterative(query, frame, frame, a_k, c_k, m_k)

        # Final attention output
        attention_output = a_k / c_k  # Shape: (batch_size, d_model)

        # Pass through the classification head
        logits = self.fc(attention_output)  # Shape: (batch_size, num_classes)

        return logits

In [5]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].to(x.device)

# Add positional encoding to the model
class AttentionRNNClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(AttentionRNNClassifier, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Learnable query
        self.query = nn.Parameter(torch.randn(1, d_model))

        # Attention RNN Cell
        self.attention_cell = AttentionRNNCell(d_model)

        # Classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Add positional encoding
        x = self.pos_encoder(x)

        # Expand query to match batch size
        query = self.query.expand(batch_size, -1)

        # Initialize hidden states
        a_k, c_k, m_k = None, None, None

        # Process the sequence iteratively
        for i in range(seq_len):
            frame = x[:, i, :]  # Get the i-th frame in the sequence
            a_k, c_k, m_k = self.attention_cell.forward_iterative(query, frame, frame, a_k, c_k, m_k)

        # Final attention output
        attention_output = a_k / c_k  # Shape: (batch_size, d_model)

        # Pass through the classification head
        logits = self.fc(attention_output)  # Shape: (batch_size, num_classes)

        return logits

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define the AttentionRNNCell (from your implementation)
class AttentionRNNCell(nn.Module):
    def __init__(self, d_model):
        super(AttentionRNNCell, self).__init__()
        self.d_model = d_model  # Dimensionality of keys/queries/values

    def forward_iterative(self, q, k, v, prev_a=None, prev_c=None, prev_m=None):
        batch_size = q.size(0)

        # Initialize previous states if not provided
        if prev_a is None:
            prev_a = torch.zeros(batch_size, self.d_model).to(k.device)
        if prev_c is None:
            prev_c = torch.zeros(batch_size, 1).to(k.device)
        if prev_m is None:
            prev_m = torch.full((batch_size, 1), 0).to(k.device)

        # Compute scores s_k = q . k^T
        s_k = torch.sum(q * k, dim=-1, keepdim=True)  # Shape: (batch_size, 1)

        # Update m_k (max cumulative score)
        m_k = torch.max(s_k, prev_m)

        # Compute exp terms for stability
        exp_term1 = torch.exp(prev_m - m_k)  # Shape: (batch_size, 1)
        exp_term2 = torch.exp(s_k - m_k)    # Shape: (batch_size, 1)

        # Update a_k and c_k
        a_k = prev_a * exp_term1 + v * exp_term2  # Shape: (batch_size, d_model)
        c_k = prev_c * exp_term1 + exp_term2      # Shape: (batch_size, 1)

        return a_k, c_k, m_k

# Define positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].to(x.device)

# Define the full model
class AttentionRNNClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(AttentionRNNClassifier, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Learnable query
        self.query = nn.Parameter(torch.randn(1, d_model))

        # Attention RNN Cell
        self.attention_cell = AttentionRNNCell(d_model)

        # Classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Add positional encoding
        x = self.pos_encoder(x)

        # Expand query to match batch size
        query = self.query.expand(batch_size, -1)

        # Initialize hidden states
        a_k, c_k, m_k = None, None, None

        # Process the sequence iteratively
        for i in range(seq_len):
            frame = x[:, i, :]  # Get the i-th frame in the sequence
            a_k, c_k, m_k = self.attention_cell.forward_iterative(query, frame, frame, a_k, c_k, m_k)

        # Final attention output
        attention_output = a_k / c_k  # Shape: (batch_size, d_model)

        # Pass through the classification head
        logits = self.fc(attention_output)  # Shape: (batch_size, num_classes)

        return logits

# Simulate a more structured dataset
class StructuredDataset(Dataset):
    def __init__(self, num_samples, seq_len, d_model, num_classes):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.d_model = d_model
        self.num_classes = num_classes

        # Simulate structured data (e.g., MFCCs)
        self.data = torch.randn(num_samples, seq_len, d_model)  # Random audio frames
        self.labels = torch.randint(0, num_classes, (num_samples,))  # Random labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Configuration
batch_size = 64
seq_len = 20  # Number of audio frames per sequence
d_model = 40  # Dimensionality of each audio frame (e.g., MFCCs)
num_classes = 5  # Number of classes for classification
num_samples = 1000  # Number of samples in the dataset
num_epochs = 1000
learning_rate = 0.001

# Create dataset and dataloader
dataset = StructuredDataset(num_samples, seq_len, d_model, num_classes)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, and optimizer
model = AttentionRNNClassifier(d_model, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, labels) in enumerate(dataloader):
        # Forward pass
        logits = model(data)
        loss = criterion(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Print statistics
        running_loss += loss.item()
        if batch_idx % 10 == 9:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], "
                  f"Loss: {running_loss / 10:.4f}, Accuracy: {100 * correct / total:.2f}%")
            running_loss = 0.0
            correct = 0
            total = 0

# Test the model
model.eval()
test_dataset = StructuredDataset(num_samples=200, seq_len=seq_len, d_model=d_model, num_classes=num_classes)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_dataloader:
        logits = model(data)
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Epoch [1/1000], Batch [10/16], Loss: 1.8247, Accuracy: 19.69%
Epoch [2/1000], Batch [10/16], Loss: 1.7364, Accuracy: 19.22%
Epoch [3/1000], Batch [10/16], Loss: 1.6970, Accuracy: 18.75%
Epoch [4/1000], Batch [10/16], Loss: 1.6747, Accuracy: 18.91%
Epoch [5/1000], Batch [10/16], Loss: 1.6774, Accuracy: 19.84%
Epoch [6/1000], Batch [10/16], Loss: 1.6766, Accuracy: 19.69%
Epoch [7/1000], Batch [10/16], Loss: 1.6509, Accuracy: 20.94%
Epoch [8/1000], Batch [10/16], Loss: 1.6509, Accuracy: 21.56%
Epoch [9/1000], Batch [10/16], Loss: 1.6447, Accuracy: 21.88%
Epoch [10/1000], Batch [10/16], Loss: 1.6205, Accuracy: 22.03%
Epoch [11/1000], Batch [10/16], Loss: 1.6279, Accuracy: 21.88%
Epoch [12/1000], Batch [10/16], Loss: 1.6261, Accuracy: 22.03%
Epoch [13/1000], Batch [10/16], Loss: 1.6063, Accuracy: 25.94%
Epoch [14/1000], Batch [10/16], Loss: 1.5992, Accuracy: 25.47%
Epoch [15/1000], Batch [10/16], Loss: 1.6054, Accuracy: 23.75%
Epoch [16/1000], Batch [10/16], Loss: 1.5985, Accuracy: 25.31%
E

In [11]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=10e-3)

# Training loop
for epoch in range(50):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, labels) in enumerate(dataloader):
        # Forward pass
        logits = model(data)
        loss = criterion(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Print statistics
        running_loss += loss.item()
        if batch_idx % 10 == 9:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], "
                  f"Loss: {running_loss / 10:.4f}, Accuracy: {100 * correct / total:.2f}%")
            running_loss = 0.0
            correct = 0
            total = 0

# Test the model
model.eval()
test_dataset = StructuredDataset(num_samples=200, seq_len=seq_len, d_model=d_model, num_classes=num_classes)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_dataloader:
        logits = model(data)
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Epoch [1/1000], Batch [10/16], Loss: 1.4497, Accuracy: 38.44%
Epoch [2/1000], Batch [10/16], Loss: 1.4349, Accuracy: 38.44%
Epoch [3/1000], Batch [10/16], Loss: 1.4433, Accuracy: 35.62%
Epoch [4/1000], Batch [10/16], Loss: 1.4231, Accuracy: 39.53%
Epoch [5/1000], Batch [10/16], Loss: 1.4391, Accuracy: 36.25%
Epoch [6/1000], Batch [10/16], Loss: 1.4261, Accuracy: 38.91%
Epoch [7/1000], Batch [10/16], Loss: 1.4298, Accuracy: 37.97%
Epoch [8/1000], Batch [10/16], Loss: 1.4267, Accuracy: 39.69%
Epoch [9/1000], Batch [10/16], Loss: 1.4328, Accuracy: 37.66%
Epoch [10/1000], Batch [10/16], Loss: 1.4390, Accuracy: 38.12%
Epoch [11/1000], Batch [10/16], Loss: 1.4371, Accuracy: 38.12%
Epoch [12/1000], Batch [10/16], Loss: 1.4130, Accuracy: 41.72%
Epoch [13/1000], Batch [10/16], Loss: 1.4136, Accuracy: 39.84%
Epoch [14/1000], Batch [10/16], Loss: 1.4257, Accuracy: 39.69%
Epoch [15/1000], Batch [10/16], Loss: 1.4355, Accuracy: 38.44%
Epoch [16/1000], Batch [10/16], Loss: 1.4409, Accuracy: 38.12%
E

In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np

# Define the ParallelAttentionScan module
class ParallelAttentionScan(nn.Module):
    def __init__(self):
        super(ParallelAttentionScan, self).__init__()

    def combine(self, mA, uA, wA, mB, uB, wB):
        mAB = torch.max(mA, mB)
        expA = torch.exp(mA - mAB)
        expB = torch.exp(mB - mAB)
        uAB = uA * expA + uB * expB
        wAB = wA * expA + wB * expB
        return mAB, uAB, wAB

    def forward(self, q, k, v):
        input_dim, batch_size, d_model = k.size()

        # Initialize with the first token
        s_initial = torch.sum(q * k[0], dim=-1, keepdim=True)
        m = s_initial
        u = torch.ones(batch_size, 1).to(k.device)
        w = v[0]

        # Iteratively combine the remaining tokens
        for i in range(1, input_dim):
            m, u, w = self.combine(m, u, w, k[i], torch.exp(k[i] - m), v[i])

        return w / u

# Define positional encoding
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        return x + self.encoding[:, :x.size(1)].to(x.device)

# Define the full model
class ParallelAttentionClassifier(nn.Module):
    def __init__(self, d_model, num_classes):
        super(ParallelAttentionClassifier, self).__init__()
        self.d_model = d_model
        self.num_classes = num_classes

        # Positional encoding
        self.pos_encoder = PositionalEncoding(d_model)

        # Learnable query
        self.query = nn.Parameter(torch.randn(1, d_model))

        # Parallel Attention Scan
        self.attention_scan = ParallelAttentionScan()

        # Classification head
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size, seq_len, d_model = x.size()

        # Add positional encoding
        x = self.pos_encoder(x)

        # Expand query to match batch size
        query = self.query.expand(batch_size, -1)

        # Transpose x to match the expected input shape for ParallelAttentionScan
        x = x.transpose(0, 1)  # Shape: (seq_len, batch_size, d_model)

        # Compute attention output
        attention_output = self.attention_scan(query, x, x)  # Shape: (batch_size, d_model)

        # Pass through the classification head
        logits = self.fc(attention_output)  # Shape: (batch_size, num_classes)

        return logits

# Simulate a more structured dataset
class StructuredDataset(Dataset):
    def __init__(self, num_samples, seq_len, d_model, num_classes):
        self.num_samples = num_samples
        self.seq_len = seq_len
        self.d_model = d_model
        self.num_classes = num_classes

        # Simulate structured data (e.g., MFCCs)
        self.data = torch.randn(num_samples, seq_len, d_model)  # Random audio frames
        self.labels = torch.randint(0, num_classes, (num_samples,))  # Random labels

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Configuration
batch_size = 64
seq_len = 20  # Number of audio frames per sequence
d_model = 40  # Dimensionality of each audio frame (e.g., MFCCs)
num_classes = 5  # Number of classes for classification
num_samples = 1000  # Number of samples in the dataset
num_epochs = 50
learning_rate = 0.001

# Create dataset and dataloader
dataset = StructuredDataset(num_samples, seq_len, d_model, num_classes)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model, loss function, and optimizer
model = ParallelAttentionClassifier(d_model, num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, labels) in enumerate(dataloader):
        # Forward pass
        logits = model(data)
        loss = criterion(logits, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Compute accuracy
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

        # Print statistics
        running_loss += loss.item()
        if batch_idx % 10 == 9:  # Print every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(dataloader)}], "
                  f"Loss: {running_loss / 10:.4f}, Accuracy: {100 * correct / total:.2f}%")
            running_loss = 0.0
            correct = 0
            total = 0

# Test the model
model.eval()
test_dataset = StructuredDataset(num_samples=200, seq_len=seq_len, d_model=d_model, num_classes=num_classes)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

correct = 0
total = 0
with torch.no_grad():
    for data, labels in test_dataloader:
        logits = model(data)
        _, predicted = torch.max(logits, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")

Epoch [1/50], Batch [10/16], Loss: 1.7429, Accuracy: 21.56%
Epoch [2/50], Batch [10/16], Loss: 1.6825, Accuracy: 22.19%
Epoch [3/50], Batch [10/16], Loss: 1.6673, Accuracy: 22.97%
Epoch [4/50], Batch [10/16], Loss: 1.6670, Accuracy: 20.31%
Epoch [5/50], Batch [10/16], Loss: 1.6576, Accuracy: 22.97%
Epoch [6/50], Batch [10/16], Loss: 1.6322, Accuracy: 23.12%
Epoch [7/50], Batch [10/16], Loss: 1.6390, Accuracy: 24.22%
Epoch [8/50], Batch [10/16], Loss: 1.6380, Accuracy: 24.69%
Epoch [9/50], Batch [10/16], Loss: 1.6154, Accuracy: 26.56%
Epoch [10/50], Batch [10/16], Loss: 1.5985, Accuracy: 28.28%
Epoch [11/50], Batch [10/16], Loss: 1.6093, Accuracy: 26.72%
Epoch [12/50], Batch [10/16], Loss: 1.5923, Accuracy: 27.81%
Epoch [13/50], Batch [10/16], Loss: 1.5914, Accuracy: 27.81%
Epoch [14/50], Batch [10/16], Loss: 1.5752, Accuracy: 27.97%
Epoch [15/50], Batch [10/16], Loss: 1.5776, Accuracy: 26.72%
Epoch [16/50], Batch [10/16], Loss: 1.5691, Accuracy: 28.91%
Epoch [17/50], Batch [10/16], Los