SESSA: Selective State Space Attention

About

Sessa is a decoder architecture that integrates self-attention into a recurrent feedback pathway.

It combines input-dependent attention routing with feedback-based recurrent aggregation, aiming to improve long-context information preservation and integration beyond standard Transformers and SSM-based models such as Mamba.

Architecture comparison

Installation

git clone https://github.com/LibratioAI/sessa.git sessa-repo
cd sessa-repo
python -m pip install -e .

Optional: FlashAttention

This repo supports FlashAttention if available.
Official FlashAttention repository: https://github.com/Dao-AILab/flash-attention

Install it separately according to the official instructions (installation depends on your CUDA/PyTorch setup). If FlashAttention is not installed, the code falls back to a reference attention implementation.
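The fallback behavior can be sketched as follows. This is an illustrative dispatch pattern, not Sessa's actual internal code: `flash_attn_func` is the public entry point of the official `flash-attention` package, and the reference path here uses PyTorch's built-in `scaled_dot_product_attention` as a stand-in for the repo's reference implementation.

```python
import torch
import torch.nn.functional as F

try:
    from flash_attn import flash_attn_func  # official flash-attention package
    HAS_FLASH = True
except ImportError:
    HAS_FLASH = False

def attention(q, k, v, use_flash=False):
    """Causal attention over (B, H, T, Dh) tensors with an optional flash path.

    Sketch only; Sessa's real dispatch logic may differ.
    """
    if use_flash and HAS_FLASH and q.is_cuda:
        # flash_attn_func expects (B, T, H, Dh) layout and fp16/bf16 inputs,
        # so transpose, downcast, then restore layout and dtype afterwards.
        qf, kf, vf = (t.transpose(1, 2).to(torch.bfloat16) for t in (q, k, v))
        out = flash_attn_func(qf, kf, vf, causal=True)
        return out.transpose(1, 2).to(q.dtype)
    # Reference path: PyTorch's built-in scaled dot-product attention.
    return F.scaled_dot_product_attention(q, k, v, is_causal=True)

q = k = v = torch.randn(2, 8, 128, 64)  # (B, H, T, Dh)
y = attention(q, k, v)                  # no CUDA here, so the reference path runs
print(y.shape)  # torch.Size([2, 8, 128, 64])
```

On CPU, or whenever `flash_attn` is missing, the function silently takes the reference path, so the same call site works in both environments.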

Important


FlashAttention is disabled by default. Enable it with use_flash=True.

Quickstart

import torch
from sessa.layer import SessaLayer  

# ---- Core dimensions ----
B = 2          # batch size
T = 128        # sequence length 
D = 512        # model width

# ---- Mixer settings ----
use_flash = True    # enable FlashAttention path if available + CUDA
use_forward_rope = True   # toggle RoPE on the forward attention branch
n_heads = 8         # number of heads for forward attention
ln_eps = 1e-5             # LayerNorm epsilon
gamma_max = 0.999         # bounds feedback gain: |gamma_t| <= gamma_max < 1
# NOTE:
# - D must be divisible by n_heads
# - (D // n_heads) must be even 

# Optional: precompute masks up to this length 
max_len = 1024

layer = SessaLayer(
    D=D,
    n_heads=n_heads,
    max_len=max_len,
    ln_eps=ln_eps,
    use_flash=use_flash,
    use_forward_rope=use_forward_rope, 
    gamma_max=gamma_max,
)

# Input tokens: (B, T, D)
x = torch.randn(B, T, D)

# If you want FlashAttention to actually run:
# - move to CUDA
# NOTE: dtype can stay fp32; the mixer will cast q/k/v to bf16/fp16 for flash
# and cast the output back to the model dtype.
if use_flash and torch.cuda.is_available():
    x = x.cuda()
    layer = layer.cuda()

y = layer(x)  # (B, T, D)
print(y.shape)

License

Licensed under the Apache License, Version 2.0. See LICENSE.
