In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from mamba_ssm import Mamba
from einops import rearrange

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# --- Backbone architecture ---
D_MODEL_DEFAULT: int = 256        # token embedding width (PatchEmbed output / block width)
N_LAYERS_DEFAULT: int = 8         # number of stacked NeuroBiMambaBlock layers
N_CHANNELS_DEFAULT: int = 129     # EEG input channels fed to PatchEmbed
PATCH_SIZE_DEFAULT: int = 10      # time samples per patch (conv kernel == stride)

# --- Mamba / block internals ---
D_STATE_DEFAULT: int = 16         # passed to Mamba as d_state
EXPAND_DEFAULT: int = 2           # hidden_dim = d_model * expand inside each block
D_CONV_DEFAULT: int = 4           # kernel size of the block's depthwise conv (also passed to Mamba)

# --- VICReg loss weights: invariance (lambda), variance (mu), covariance (nu) ---
VICREG_LAMBDA_DEFAULT: float = 25.0
VICREG_MU_DEFAULT: float = 25.0
VICREG_NU_DEFAULT: float = 1.0

# --- Mixture Density Network head ---
MDN_COMPONENTS_DEFAULT: int = 5   # number of Gaussian mixture components

In [3]:
# --- Patch Embedding Layer ---
class PatchEmbed(nn.Module):
    """
    Cut a multichannel EEG signal into non-overlapping temporal patches and
    linearly embed each patch.

    Shape: (Batch, Channels, Time) -> (Batch, NumPatches, EmbedDim).
    For Time >= patch_size, NumPatches = Time // patch_size; trailing samples
    that do not fill a whole patch are dropped by the strided convolution.
    """
    def __init__(self, n_channels=N_CHANNELS_DEFAULT,
                 embed_dim=D_MODEL_DEFAULT,
                 patch_size=PATCH_SIZE_DEFAULT):
        super().__init__()
        # kernel_size == stride: every output step covers exactly one disjoint patch.
        self.proj = nn.Conv1d(n_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        # Conv1d is channels-first (B, EmbedDim, NumPatches); return token-major.
        return self.proj(x).permute(0, 2, 1)

In [8]:
class NeuroBiMambaBlock(nn.Module):
    """
    Pre-norm residual block that mixes a sequence in both directions with two
    Mamba layers. Input and output shape: (B, L, d_model).

    Data flow (hidden_dim = d_model * expand):
      1. LayerNorm, then a bias-free linear map to 2 * hidden_dim features,
         split into a conv branch and a gate branch.
      2. Conv branch: depthwise Conv1d along the sequence axis, padded by
         d_conv - 1 and cropped back to L, so output step t only uses input
         steps t - (d_conv - 1) .. t; then SiLU.
      3. The activated features go through `mamba_fwd` as-is and through
         `mamba_bwd` on the time-reversed sequence (reversed back afterwards);
         both results are concatenated on the feature axis.
      4. Each half is multiplied by SiLU(gate), the result is projected back to
         d_model by a bias-free linear layer and added to the block input.
    """
    def __init__(self,
                 d_model: int = D_MODEL_DEFAULT,
                 d_state: int = D_STATE_DEFAULT,
                 expand: int = EXPAND_DEFAULT,
                 d_conv: int = D_CONV_DEFAULT):
        super().__init__()
        hidden = d_model * expand
        self.hidden_dim = hidden

        # One projection produces both branches: [conv branch | gate branch].
        self.in_proj = nn.Linear(d_model, hidden * 2, bias=False)

        # groups == channels makes the conv depthwise; the extra right-hand
        # outputs created by the symmetric padding are cropped in _conv_activate.
        self.conv1d = nn.Conv1d(
            hidden, hidden, d_conv,
            padding=d_conv - 1,
            groups=hidden,
            bias=True,
        )

        self.activation = nn.SiLU()

        # Separate weights per direction; expand=1 because in_proj already widened.
        self.mamba_fwd = Mamba(d_model=hidden, d_state=d_state, d_conv=d_conv, expand=1)
        self.mamba_bwd = Mamba(d_model=hidden, d_state=d_state, d_conv=d_conv, expand=1)

        # (fwd | bwd) features -> d_model
        self.out_proj = nn.Linear(hidden * 2, d_model, bias=False)

        self.norm = nn.LayerNorm(d_model)

    def _split_projection(self, x: torch.Tensor):
        """(B, L, d_model) -> conv branch, gate branch; each (B, L, hidden_dim)."""
        branches = torch.chunk(self.in_proj(x), 2, dim=-1)
        return branches[0], branches[1]

    def _conv_activate(self, conv_input: torch.Tensor, seq_len: int):
        """Depthwise conv over the sequence axis, cropped to seq_len, then SiLU."""
        channels_first = rearrange(conv_input, "b l d -> b d l")
        convolved = self.conv1d(channels_first)[..., :seq_len]
        return self.activation(rearrange(convolved, "b d l -> b l d"))

    def _run_bi_mamba(self, activated: torch.Tensor):
        """Run both Mamba directions; returns (B, L, 2 * hidden_dim)."""
        forward_ctx = self.mamba_fwd(activated)
        # Reverse time, scan, then reverse back so positions line up again.
        backward_ctx = self.mamba_bwd(activated.flip(1)).flip(1)
        return torch.cat((forward_ctx, backward_ctx), dim=-1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        normed = self.norm(x)

        conv_branch, gate = self._split_projection(normed)
        features = self._conv_activate(conv_branch, seq_len=normed.shape[1])
        bi_context = self._run_bi_mamba(features)

        # The same SiLU gate modulates both directional halves.
        gate_act = self.activation(gate)
        fwd_half, bwd_half = bi_context.chunk(2, dim=-1)
        gated = torch.cat([fwd_half * gate_act, bwd_half * gate_act], dim=-1)

        return self.out_proj(gated) + shortcut

JEPA Backbone: EEG patch embedding → learnable CLS token → stacked bidirectional Mamba blocks → final LayerNorm

In [9]:
class EegMambaJEPA(nn.Module):
    """
    JEPA-style EEG backbone: patch embedding, a learnable CLS token, optional
    learnable positional embeddings, a stack of NeuroBiMambaBlock layers and a
    final LayerNorm.

    Input:  (B, C, T) raw EEG.
    Output: (B, d_model) final-layer representation of the CLS token.

    Args:
        d_model: token embedding width.
        n_layer: number of NeuroBiMambaBlock layers.
        n_channels: number of EEG input channels (C).
        patch_size: time samples per patch.
        d_state: passed to each block (and on to its Mamba layers).
        expand: hidden-width expansion factor inside each block.
        use_pos_embed: add learnable absolute positional embeddings.
        max_len: maximum number of patches the positional table supports
            (one extra slot is reserved for the CLS token).
        d_conv: kernel size of each block's depthwise conv (also passed to its
            Mamba layers). Defaults to D_CONV_DEFAULT, the block's own default.
    """
    def __init__(
        self,
        d_model: int = D_MODEL_DEFAULT,
        n_layer: int = N_LAYERS_DEFAULT,
        n_channels: int = N_CHANNELS_DEFAULT,
        patch_size: int = PATCH_SIZE_DEFAULT,
        d_state: int = D_STATE_DEFAULT,
        expand: int = EXPAND_DEFAULT,
        use_pos_embed: bool = False,
        max_len: int = 1000,
        d_conv: int = D_CONV_DEFAULT,
    ):
        super().__init__()
        self.d_model = d_model
        self.use_pos_embed = use_pos_embed

        # Patch embedding: (B, C, T) -> (B, NumPatches, d_model)
        self.patch_embed = PatchEmbed(n_channels, d_model, patch_size)

        # Learnable CLS token, prepended at position 0.
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Optional positional embeddings: max_len patch slots + 1 CLS slot.
        if self.use_pos_embed:
            self.pos_embed = nn.Parameter(torch.randn(1, max_len + 1, d_model))

        # Stack of NeuroBiMamba blocks
        self.mamba_blocks = self._build_mamba_stack(n_layer, d_model, d_state, expand, d_conv)

        # Final layer norm
        self.norm_f = nn.LayerNorm(d_model)

    def _build_mamba_stack(self, n_layer: int, d_model: int, d_state: int, expand: int,
                           d_conv: int = D_CONV_DEFAULT) -> nn.Sequential:
        """Create a sequential stack of `n_layer` NeuroBiMambaBlock modules."""
        blocks = [
            NeuroBiMambaBlock(d_model=d_model, d_state=d_state, expand=expand, d_conv=d_conv)
            for _ in range(n_layer)
        ]
        return nn.Sequential(*blocks)

    def _prepend_cls_token(self, x: torch.Tensor) -> torch.Tensor:
        """Prepend the CLS token to a batch of patch embeddings (B, N, D) -> (B, 1 + N, D)."""
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, d_model)
        return torch.cat((cls_tokens, x), dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, C, T)
        returns: (B, d_model) -- embedding of the CLS token

        Raises:
            ValueError: if use_pos_embed is set and 1 + NumPatches exceeds the
                positional table (i.e. NumPatches > max_len).
        """
        x = self.patch_embed(x)          # (B, NumPatches, d_model)
        x = self._prepend_cls_token(x)   # (B, 1 + NumPatches, d_model)

        if self.use_pos_embed:
            seq_len = x.shape[1]
            capacity = self.pos_embed.shape[1]
            # Without this check, slicing silently truncates the table and the
            # addition fails later with an opaque broadcasting error.
            if seq_len > capacity:
                raise ValueError(
                    f"Input yields {seq_len} tokens (incl. CLS) but the positional "
                    f"embedding holds {capacity}; increase max_len or shorten the input."
                )
            x = x + self.pos_embed[:, :seq_len]

        # NOTE: at position 0 the block's causal conv front-end and (assuming Mamba
        # is causal) the forward branch see only the CLS token; patch information
        # reaches it via the backward (flipped) branch.
        x = self.mamba_blocks(x)
        x = self.norm_f(x)

        return x[:, 0]

In [11]:
# Smoke test: a random EEG batch through the backbone should yield one d_model
# vector per example.
# NOTE(review): mamba_ssm's fused kernels appear to require a CUDA GPU; on a
# CPU-only machine this cell may fail — confirm.
torch.manual_seed(0)  # reproducible input and initial weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.randn(8, 129, 1000).to(device)  # Example input: (B, C, T)
model = EegMambaJEPA(
    d_model=256,
    n_layer=8,
    n_channels=129,
    patch_size=10,
    d_state=256,
    expand=4,
    use_pos_embed=False,
    max_len=1000,
).to(device)

# Inference only: skip building an autograd graph for this shape check.
model.eval()
with torch.no_grad():
    out = model(x)

assert out.shape == (x.shape[0], model.d_model), f"unexpected output shape {tuple(out.shape)}"
print(out.shape)

torch.Size([8, 256])


In [6]:
class VICReg(nn.Module):
    """
    VICReg loss (Variance-Invariance-Covariance Regularization) with its own
    3-layer MLP projector applied to both views.

    loss = lambda_val * invariance + mu_val * variance + nu_val * covariance
      invariance : MSE between the two projected views
      variance   : mean hinge relu(1 - std) over features, averaged over views
      covariance : sum of squared off-diagonal covariance entries / D,
                   averaged over views

    Args:
        d_model: width of the input embeddings and of every projector layer.
        lambda_val, mu_val, nu_val: weights of the three terms.
        eps: added to the variance before the sqrt for numerical stability.
    """
    def __init__(self, d_model = D_MODEL_DEFAULT,
                       lambda_val = VICREG_LAMBDA_DEFAULT,
                       mu_val = VICREG_MU_DEFAULT,
                       nu_val = VICREG_NU_DEFAULT,
                       eps=1e-4):

        super().__init__()

        self.lambda_val = lambda_val
        self.mu_val = mu_val
        self.nu_val = nu_val
        self.eps = eps

        # Projector used during pre-training
        self.projector = nn.Sequential(
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, d_model), nn.ReLU(),
            nn.Linear(d_model, d_model)
        )

    def _variance_hinge(self, z_centered: torch.Tensor) -> torch.Tensor:
        """Mean over features of relu(1 - std); pushes each feature's std up to 1."""
        std = torch.sqrt(z_centered.var(dim=0) + self.eps)
        return torch.mean(nn.functional.relu(1 - std))

    @staticmethod
    def _off_diagonal_penalty(z_centered: torch.Tensor) -> torch.Tensor:
        """Sum of squared off-diagonal entries of the (D, D) covariance, divided by D."""
        n, d = z_centered.shape
        cov = (z_centered.T @ z_centered) / (n - 1)
        off_diag = ~torch.eye(d, device=z_centered.device).bool()
        return cov[off_diag].pow(2).sum() / d

    def forward(self, z1, z2):
        """
        Args:
            z1, z2: (B, d_model) embeddings of two views of the same B samples.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: if the inputs are not 2-D tensors of equal shape, or if
                B < 2 (the unbiased variance and the covariance divide by B - 1,
                so a single sample would silently give NaN/inf).
        """
        if z1.dim() != 2 or z1.shape != z2.shape:
            raise ValueError(
                f"VICReg expects two (B, D) tensors of equal shape, got "
                f"{tuple(z1.shape)} and {tuple(z2.shape)}"
            )
        if z1.shape[0] < 2:
            raise ValueError(f"VICReg needs a batch of at least 2 samples, got {z1.shape[0]}")

        z1p = self.projector(z1)
        z2p = self.projector(z2)

        # Invariance term (MSE)
        repr_loss = nn.functional.mse_loss(z1p, z2p)

        z1_centered = z1p - z1p.mean(dim=0)
        z2_centered = z2p - z2p.mean(dim=0)

        # Variance term (hinge on per-feature std), averaged over the two views
        std_loss = (self._variance_hinge(z1_centered) + self._variance_hinge(z2_centered)) / 2

        # Covariance term, averaged over the two views.
        # NOTE(review): the reference VICReg implementation sums (not averages)
        # the two covariance terms, so nu_val acts at half that strength here —
        # confirm this is intended.
        cov_loss = (self._off_diagonal_penalty(z1_centered)
                    + self._off_diagonal_penalty(z2_centered)) / 2

        return (self.lambda_val * repr_loss
                + self.mu_val * std_loss
                + self.nu_val * cov_loss)


In [7]:
# The projector's input width must match the backbone's embedding width, and its
# weights must sit on the same device as the embeddings it will receive.
vicreg = VICReg(d_model=model.d_model).to(device)

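Toy example: why pushing negatives apart (the contrastive InfoNCE loss) prevents representation collapse in self-supervised learning. For contrast, VICReg above avoids collapse without negatives, through its variance and covariance terms.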

In [8]:
# Toy InfoNCE example: one anchor, one positive (augmented view), two negatives.
# Feature vectors: similar within a pair, different across pairs.
z1a = torch.tensor([0.9, 0.1, -0.2]).unsqueeze(0)  # Dog-like
z1b = torch.tensor([0.8, 0.2, -0.1]).unsqueeze(0)  # Augmented dog
z2a = torch.tensor([-0.3, 0.9, 0.4]).unsqueeze(0)  # Cat-like
z2b = torch.tensor([-0.2, 0.8, 0.5]).unsqueeze(0)  # Augmented cat

TAU = 0.5  # softmax temperature


def info_nce(anchor, positive, negatives, tau=TAU):
    """
    InfoNCE loss for a single anchor.

    Computes -log( exp(s_pos/tau) / (exp(s_pos/tau) + sum_k exp(s_neg_k/tau)) )
    with cosine similarities s, evaluated as logsumexp(logits) - logits[0]
    (algebraically identical, but it cannot overflow for small tau).

    anchor, positive: (1, D) tensors; negatives: iterable of (1, D) tensors.
    Returns a 0-d tensor.
    """
    sims = torch.cat(
        [F.cosine_similarity(anchor, positive)]
        + [F.cosine_similarity(anchor, neg) for neg in negatives]
    )
    logits = sims / tau
    return torch.logsumexp(logits, dim=0) - logits[0]


# cosine_similarity normalizes internally, so no explicit F.normalize is needed.
sim_pos = F.cosine_similarity(z1a, z1b)   # ≈ 0.987 (high)
sim_neg1 = F.cosine_similarity(z1a, z2a)  # ≈ -0.272 (low)
sim_neg2 = F.cosine_similarity(z1a, z2b)  # ≈ -0.224 (low)

print(f"Positive sim: {sim_pos.item():.3f}")
print(f"Negative sim: {sim_neg1.item():.3f}, {sim_neg2.item():.3f}")

# exp(0.987/0.5) ≈ 7.19; denominator ≈ 7.19 + 0.58 + 0.64 ≈ 8.41; -log(7.19/8.41) ≈ 0.157
loss = info_nce(z1a, z1b, [z2a, z2b])
print(f"Loss: {loss.item():.3f}")

# Collapsed encoder: every input maps to the same vector, so all similarities are 1
# and the loss is stuck at its chance level log(1 + #negatives) = log(3) ≈ 1.099,
# far above the loss of the well-separated embeddings.
z_collapsed = torch.ones(1, 3)
loss_collapsed = info_nce(z_collapsed, z_collapsed, [z_collapsed, z_collapsed])
print(f"Collapsed loss: {loss_collapsed.item():.3f}")


Positive sim: 0.987
Negative sim: -0.272, -0.224
Loss: 0.157


In [9]:
# --- MDN Head & Loss (Loss function included for reference) ---
class MDNHead(nn.Module):
    """
    Mixture Density Network head that predicts the parameters of a 1-D Gaussian
    mixture for each input vector.

    Inputs:
      x: Tensor of shape (..., input_dim), e.g. (B, input_dim) for pooled CLS
         embeddings or (B, L, input_dim) for per-token predictions.

    Outputs (each tensor shape (..., n_components)):
      pi    : mixture weights (softmax over the component axis, sum to 1)
      sigma : positive standard deviations (softplus output + min_sigma)
      mu    : component means
    """
    def __init__(self,
                 input_dim: int = D_MODEL_DEFAULT,
                 n_components: int = MDN_COMPONENTS_DEFAULT,
                 min_sigma: float = 1e-6):

        super().__init__()
        self.n_components = n_components
        # Floor added to sigma: keeps it strictly positive even if softplus underflows to 0.
        self.min_sigma = min_sigma

        # Separate linear heads for mixture logits, (unconstrained) scales and means
        self.pi = nn.Linear(input_dim, n_components)
        self.sigma = nn.Linear(input_dim, n_components)
        self.mu = nn.Linear(input_dim, n_components)

        # Softplus maps the unconstrained scale to a smooth positive value
        self.softplus = nn.Softplus()

    def forward(self, x):
        """
        Predict MDN parameters.

        Args:
            x: (..., input_dim)

        Returns:
            pi    : mixture weights (..., n_components)
            sigma : standard deviations (..., n_components)
            mu    : component means (..., n_components)
        """
        # Components are always the last axis, so normalize there; this is
        # identical to dim=1 for (B, input_dim) input and also correct for
        # inputs with extra leading dimensions such as (B, L, input_dim).
        pi = torch.softmax(self.pi(x), dim=-1)

        # Means
        mu = self.mu(x)

        # Positive scales via softplus for numerical stability
        sigma = self.softplus(self.sigma(x)) + self.min_sigma

        return pi, sigma, mu
