Merged

Attn #138

138 changes: 138 additions & 0 deletions mambular/arch_utils/mambattn_arch.py
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .normalization_layers import RMSNorm
from .mamba_arch import ResidualBlock


class MambAttn(nn.Module):
"""Mamba model composed of alternating MambaBlocks and Attention layers.

Attributes:
layers (nn.ModuleList): Alternating ResidualBlock (Mamba) and nn.MultiheadAttention layers that make up the model.
"""

def __init__(
self,
d_model=32,
n_layers=8,
n_attention_layers=1,  # Number of attention layers added on top of n_layers
n_mamba_per_attention=1,  # Number of Mamba blocks between consecutive attention layers
n_heads=4, # Number of attention heads
expand_factor=2,
bias=False,
d_conv=8,
conv_bias=True,
dropout=0.0,
attn_dropout=0.1,
dt_rank="auto",
d_state=16,
dt_scale=1.0,
dt_init="random",
dt_max=0.1,
last_layer="attn", # Define the desired last layer type
dt_min=1e-03,
dt_init_floor=1e-04,
norm=RMSNorm,
activation=F.silu,
bidirectional=False,
use_learnable_interaction=False,
layer_norm_eps=1e-05,
AD_weight_decay=False,
BC_layer_norm=True,
):
super().__init__()

# Define Mamba and Attention layers alternation
self.layers = nn.ModuleList()

total_blocks = n_layers + n_attention_layers # Total blocks to be created
attention_count = 0

for i in range(total_blocks):
# Insert an attention layer after every n_mamba_per_attention Mamba blocks
if (i + 1) % (n_mamba_per_attention + 1) == 0:
self.layers.append(
nn.MultiheadAttention(
embed_dim=d_model, num_heads=n_heads, dropout=attn_dropout
)
)
attention_count += 1
else:
self.layers.append(
ResidualBlock(
d_model,
expand_factor,
bias,
d_conv,
conv_bias,
dropout,
dt_rank,
d_state,
dt_scale,
dt_init,
dt_max,
dt_min,
dt_init_floor,
norm,
activation,
bidirectional,
use_learnable_interaction,
layer_norm_eps,
AD_weight_decay,
BC_layer_norm,
)
)

# Check the type of the last layer and append the desired one if necessary
if last_layer == "attn":
if not isinstance(self.layers[-1], nn.MultiheadAttention):
self.layers.append(
nn.MultiheadAttention(
embed_dim=d_model, num_heads=n_heads, dropout=attn_dropout
)
)
else:
if not isinstance(self.layers[-1], ResidualBlock):
self.layers.append(
ResidualBlock(
d_model,
expand_factor,
bias,
d_conv,
conv_bias,
dropout,
dt_rank,
d_state,
dt_scale,
dt_init,
dt_max,
dt_min,
dt_init_floor,
norm,
activation,
bidirectional,
use_learnable_interaction,
layer_norm_eps,
AD_weight_decay,
BC_layer_norm,
)
)

def forward(self, x):
for layer in self.layers:
if isinstance(layer, nn.MultiheadAttention):
# nn.MultiheadAttention defaults to batch_first=False, so switch to (seq_len, batch, embed_dim)
x = x.transpose(0, 1)
x, _ = layer(x, x, x)
x = x.transpose(0, 1) # Switch back to (batch, seq_len, embed_dim)
else:
# Otherwise, pass through Mamba block
x = layer(x)

return x
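A minimal smoke-test sketch of the new module, useful when reviewing the alternation logic. It assumes the import path below matches this diff and that `ResidualBlock` from `mamba_arch` accepts these constructor arguments and preserves the `(batch, seq_len, d_model)` shape; the tensor sizes are arbitrary.

```python
import torch
from mambular.arch_utils.mambattn_arch import MambAttn  # path assumed from this diff

# Defaults: d_model=32, n_layers=8, attention inserted after every Mamba block
model = MambAttn(d_model=32, n_heads=4)

x = torch.randn(16, 10, 32)  # (batch, seq_len, d_model)
out = model(x)
print(out.shape)  # expected: torch.Size([16, 10, 32])

# Inspect the alternation pattern produced by the constructor
print([type(layer).__name__ for layer in model.layers])
```

With the defaults, the last printed entry should be `MultiheadAttention`, matching `last_layer="attn"`.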
141 changes: 141 additions & 0 deletions mambular/arch_utils/rnn_utils.py
@@ -0,0 +1,141 @@
import torch
import torch.nn as nn


class ConvRNN(nn.Module):
def __init__(
self,
model_type: str, # 'RNN', 'LSTM', or 'GRU'
input_size: int,  # Number of input features
hidden_size: int, # Number of hidden units in RNN layers
num_layers: int, # Number of RNN layers
bidirectional: bool, # Whether RNN is bidirectional
rnn_dropout: float, # Dropout rate for RNN
bias: bool, # Bias for RNN
conv_bias: bool, # Bias for Conv1d
rnn_activation: str = "tanh",  # Nonlinearity for plain RNN ('tanh' or 'relu'); ignored for LSTM/GRU
d_conv: int = 4, # Kernel size for Conv1d
residuals: bool = False, # Whether to use residual connections
):
super().__init__()

# Choose RNN layer based on model_type
rnn_layer = {"RNN": nn.RNN, "LSTM": nn.LSTM, "GRU": nn.GRU}[model_type]

self.input_size = input_size  # Number of input features
self.hidden_size = hidden_size # Number of hidden units in RNN
self.num_layers = num_layers # Number of RNN layers
self.bidirectional = bidirectional  # Whether RNN is bidirectional
self.rnn_output_size = hidden_size * (2 if bidirectional else 1)  # feature dim of each RNN layer's output
self.rnn_type = model_type
self.residuals = residuals

# Convolutional layers
self.convs = nn.ModuleList()

if self.residuals:
self.residual_matrix = nn.ParameterList(
[
nn.Parameter(torch.randn(self.rnn_output_size, self.rnn_output_size))
for _ in range(num_layers)
]
)

# First Conv1d layer uses input_size
self.convs.append(
nn.Conv1d(
in_channels=self.input_size, # Input size for first layer
out_channels=self.input_size, # Output channels (128)
kernel_size=d_conv,
padding=d_conv - 1,  # output is trimmed back to seq_length in forward()
bias=conv_bias,
groups=self.input_size, # Depthwise convolution, each channel independent
)
)

# Subsequent Conv1d layers use hidden_size as input
for i in range(self.num_layers - 1):
self.convs.append(
nn.Conv1d(
in_channels=self.rnn_output_size,  # feature dim of the previous RNN layer's output
out_channels=self.rnn_output_size,
kernel_size=d_conv,
padding=d_conv - 1,  # output is trimmed back to seq_length in forward()
bias=conv_bias,
groups=self.rnn_output_size,  # depthwise convolution, each channel independent
)
)

# Initialize the RNN layers
self.rnns = nn.ModuleList()
for i in range(self.num_layers):
if model_type == "RNN":
rnn = rnn_layer(
input_size=self.input_size if i == 0 else self.rnn_output_size,  # first layer takes the raw features
hidden_size=self.hidden_size,
num_layers=1, # One RNN layer at a time
bidirectional=self.bidirectional,
batch_first=True,
dropout=rnn_dropout if i < self.num_layers - 1 else 0,
bias=bias,
nonlinearity=rnn_activation,  # only nn.RNN accepts a nonlinearity argument
)
else: # For LSTM or GRU
rnn = rnn_layer(
input_size=self.input_size if i == 0 else self.rnn_output_size,  # first layer takes the raw features
hidden_size=self.hidden_size,
num_layers=1, # One RNN layer at a time
bidirectional=self.bidirectional,
batch_first=True,
dropout=rnn_dropout if i < self.num_layers - 1 else 0,
bias=bias,
)
self.rnns.append(rnn)

def forward(self, x):
"""
Forward pass through Conv-RNN layers.

Parameters
----------
x : torch.Tensor
    Input tensor of shape (batch_size, seq_length, input_size).

Returns
-------
output : torch.Tensor
    Output tensor of shape (batch_size, seq_length, hidden_size * num_directions).
hidden : torch.Tensor or tuple
    Hidden state returned by the last RNN layer (a (h_n, c_n) tuple for LSTM).
"""
_, L, _ = x.shape
if self.residuals:
residual = x

# Loop through the RNN layers and apply 1D convolution before each
for i in range(self.num_layers):
# Transpose to (batch_size, input_size, seq_length) for Conv1d
x = x.transpose(1, 2)

# Apply the 1D convolution
x = self.convs[i](x)[:, :, :L]

# Transpose back to (batch_size, seq_length, input_size)
x = x.transpose(1, 2)

# Pass through the RNN layer
x, hidden = self.rnns[i](x)

# Residual connection with learnable matrix
if self.residuals:
if i > 0:  # skip the first layer; there is no previous RNN output to project yet
residual_proj = torch.matmul(residual, self.residual_matrix[i])
x = x + residual_proj

# Update residual for next layer
residual = x

return x, hidden
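A similar usage sketch for `ConvRNN`, again assuming the import path from this diff; the sizes are arbitrary, and `rnn_dropout` is left at 0.0 since each internal RNN is constructed with a single layer.

```python
import torch
from mambular.arch_utils.rnn_utils import ConvRNN  # path assumed from this diff

model = ConvRNN(
    model_type="LSTM",
    input_size=128,
    hidden_size=64,
    num_layers=2,
    bidirectional=False,
    rnn_dropout=0.0,  # per-RNN dropout is a no-op here because each RNN has num_layers=1
    bias=True,
    conv_bias=True,
    d_conv=4,
    residuals=True,
)

x = torch.randn(8, 20, 128)  # (batch_size, seq_length, input_size)
out, hidden = model(x)
print(out.shape)  # expected: torch.Size([8, 20, 64])
```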
2 changes: 2 additions & 0 deletions mambular/base_models/__init__.py
@@ -6,6 +6,7 @@
from .resnet import ResNet
from .tabtransformer import TabTransformer
from .mambatab import MambaTab
from .mambattn import MambAttn

__all__ = [
"TaskModel",
@@ -16,4 +17,5 @@
"MLP",
"BaseModel",
"MambaTab",
"MambAttn",
]
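With this export in place, the new base model should be importable from the subpackage directly. A quick check, assuming the package is installed and `mambular/base_models/mambattn.py` defines the class imported above:

```python
from mambular.base_models import MambAttn

print(MambAttn.__module__)  # expected: mambular.base_models.mambattn
```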