In [1]:
# Debug the final einsum issue - local paths
import sys
import os
sys.path.append('/Users/nolanrobbins/Desktop/VS Code Projects/MovieLens-RecSys')

import torch
import torch.nn as nn

# Debug both einsum operations with correct tensor dimensions
batch_size = 1024
d_model = 64
d_state = 16

# Create the tensors involved in the failing einsums
states = torch.randn(batch_size, d_state)  # [1024, 16] 
u_t = torch.randn(batch_size, d_model)     # [1024, 64]  
dB_t = torch.randn(batch_size, d_state, d_model)  # [1024, 16, 64]
C = torch.randn(d_state, d_model)          # [16, 64]

print("=== TENSOR SHAPES ===")
print(f"states.shape: {states.shape}")     # [1024, 16]
print(f"u_t.shape: {u_t.shape}")           # [1024, 64]  
print(f"dB_t.shape: {dB_t.shape}")         # [1024, 16, 64]
print(f"C.shape: {C.shape}")               # [16, 64]

print("\n=== EINSUM 1: 'bsd,bd->bs' ===")
print("dB[:, t, :, :] should be [batch_size, d_state, d_model]")
print("u_t should be [batch_size, d_model]") 
print("Result should be [batch_size, d_state]")
try:
    result1 = torch.einsum('bsd,bd->bs', dB_t, u_t)
    print(f"✅ First einsum succeeded: {result1.shape}")
except Exception as e:
    print(f"❌ First einsum failed: {e}")

print("\n=== EINSUM 2: 'ds,bs->bd' ===")  
print("C should be [d_state, d_model]")
print("states should be [batch_size, d_state]")
print("Result should be [batch_size, d_model]")
try:
    result2 = torch.einsum('ds,bs->bd', C, states)
    print(f"✅ Second einsum succeeded: {result2.shape}")
except Exception as e:
    print(f"❌ Second einsum failed: {e}")

# The issue might be that the einsum equation is wrong. Let's check what should happen:
print("\n=== ANALYZING EINSUM 2 ===")
print("Current equation: 'ds,bs->bd'")
print("  d=d_state(16), s=d_state(16), b=batch_size(1024)")  
print("  This expects C[d_state, d_state] but we have C[d_state, d_model]")
print("  The equation should probably be: 'dm,bm->bd' (not 'ds,bs->bd')")

print(f"\nTesting corrected equation 'dm,bm->bd':")
try:
    # This should be the correct einsum for C[d_state, d_model] * states[batch, d_state] 
    # But wait, that doesn't make sense either...
    
    # Let's think: if we want [batch_size, d_model] output
    # And we have C[d_state, d_model] and states[batch_size, d_state]
    # Then we want: states @ C = [batch_size, d_state] @ [d_state, d_model] = [batch_size, d_model]
    result2_corrected = torch.einsum('bs,sd->bd', states, C)
    print(f"✅ Corrected einsum 'bs,sd->bd' succeeded: {result2_corrected.shape}")
except Exception as e:
    print(f"❌ Corrected einsum failed: {e}")
    
print(f"\nAlternative using matrix multiplication:")
try:
    result_matmul = states @ C  # [1024, 16] @ [16, 64] = [1024, 64]
    print(f"✅ Matrix multiplication succeeded: {result_matmul.shape}")
except Exception as e:
    print(f"❌ Matrix multiplication failed: {e}")

=== TENSOR SHAPES ===
states.shape: torch.Size([1024, 16])
u_t.shape: torch.Size([1024, 64])
dB_t.shape: torch.Size([1024, 16, 64])
C.shape: torch.Size([16, 64])

=== EINSUM 1: 'bsd,bd->bs' ===
dB[:, t, :, :] should be [batch_size, d_state, d_model]
u_t should be [batch_size, d_model]
Result should be [batch_size, d_state]
✅ First einsum succeeded: torch.Size([1024, 16])

=== EINSUM 2: 'ds,bs->bd' ===
C should be [d_state, d_model]
states should be [batch_size, d_state]
Result should be [batch_size, d_model]
❌ Second einsum failed: einsum(): subscript s has size 16 for operand 1 which does not broadcast with previously seen size 64

=== ANALYZING EINSUM 2 ===
Current equation: 'ds,bs->bd'
  d=d_state(16), s=d_state(16), b=batch_size(1024)
  This expects C[d_state, d_state] but we have C[d_state, d_model]
  The equation should probably be: 'dm,bm->bd' (not 'ds,bs->bd')

Testing corrected equation 'dm,bm->bd':
✅ Corrected einsum 'bs,sd->bd' succeeded: torch.Size([1024, 64])

Alternative 

In [2]:
import sys

sys.path.insert(0, '/MovieLens-RecSys/.venv/lib/python3.11/site-packages')
sys.path.append('/MovieLens-RecSys')

import torch
import pandas as pd
from pathlib import Path

# Now we can debug the tensor shapes and imports
print("Setup complete!")

Setup complete!


In [3]:
# Import S5Layer with correct local paths
import sys
import os
sys.path.append('/Users/nolanrobbins/Desktop/VS Code Projects/MovieLens-RecSys')

from models.sota_2025.components.state_space_models import S5Layer

print("S5Layer imported successfully!")

S5Layer imported successfully!


In [4]:
# Create tensors with the same dimensions as your training
batch_size = 1024
seq_len = 200
d_model = 64  # Typical embedding dimension

# Simulate the inputs that are causing the problem
x = torch.randn(batch_size, seq_len, d_model)
time_intervals = torch.randn(batch_size, seq_len - 1)  # This is likely the issue - seq_len-1

print(f"x.shape: {x.shape}")
print(f"time_intervals.shape: {time_intervals.shape}")
print(f"time_intervals expected shape: {(batch_size, seq_len)}")

x.shape: torch.Size([1024, 200, 64])
time_intervals.shape: torch.Size([1024, 199])
time_intervals expected shape: (1024, 200)


In [5]:
# Create an S5Layer instance
s5_layer = S5Layer(d_model=d_model, d_state=16)

# Try the forward pass that's failing
try:
    output = s5_layer(x, time_intervals)
    print("✅ S5Layer forward pass succeeded!")
    print(f"Output shape: {output.shape}")
except Exception as e:
    print(f"❌ Error in S5Layer: {e}")
    print("This is the exact error we need to fix!")

✅ S5Layer forward pass succeeded!
Output shape: torch.Size([1024, 200, 64])


In [6]:
# Let's see what happens in the S5Layer forward method step by step
d_model = 64
d_state = 16

# Simulate the problematic part
batch_size, current_len = time_intervals.shape
print(f"time_intervals.shape: {time_intervals.shape}")
print(f"batch_size: {batch_size}, current_len: {current_len}")

# Create dt tensor like in the model
dt = torch.randn(batch_size, 200, 1)  # This should match seq_len
expected_len = dt.shape[1]
print(f"dt.shape: {dt.shape}")
print(f"expected_len: {expected_len}")

# Test the padding logic from my fix
if current_len < expected_len:
    print(f"Need to pad: {current_len} -> {expected_len}")
    padding = time_intervals[:, -1:].expand(batch_size, expected_len -
current_len)
    print(f"padding.shape: {padding.shape}")
    time_intervals_padded = torch.cat([time_intervals, padding], dim=1)
    print(f"time_intervals_padded.shape: {time_intervals_padded.shape}")
else:
    time_intervals_padded = time_intervals

# Test the final operation
dt_squeezed = dt.squeeze(-1)
time_padded_unsqueezed = time_intervals_padded.unsqueeze(-1)

print(f"dt.squeeze(-1).shape: {dt_squeezed.shape}")
print(f"time_intervals_padded.unsqueeze(-1).shape: {time_padded_unsqueezed.shape}")

# This should work now
try:
    result = dt_squeezed * time_padded_unsqueezed
    print("✅ Tensor multiplication succeeded!")
    print(f"Result shape: {result.shape}")
except Exception as e:
    print(f"❌ Still failing: {e}")

time_intervals.shape: torch.Size([1024, 199])
batch_size: 1024, current_len: 199
dt.shape: torch.Size([1024, 200, 1])
expected_len: 200
Need to pad: 199 -> 200
padding.shape: torch.Size([1024, 1])
time_intervals_padded.shape: torch.Size([1024, 200])
dt.squeeze(-1).shape: torch.Size([1024, 200])
time_intervals_padded.unsqueeze(-1).shape: torch.Size([1024, 200, 1])
❌ Still failing: The size of tensor a (1024) must match the size of tensor b (200) at non-singleton dimension 1


In [7]:
# Let's step through the S5Layer forward method manually
import torch.nn as nn

class DebugS5Layer(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state

        # Same initialization as S5Layer
        self.dt_proj = nn.Linear(d_model, 1)
        self.dt_min = 1e-3
        self.dt_max = 1e-1
        self.A_log = nn.Parameter(torch.randn(d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_state, d_model))

    def forward(self, x, time_intervals=None):
        print(f"Input x.shape: {x.shape}")
        print(f"Input time_intervals.shape: {time_intervals.shape if time_intervals is not None else None}")

        # Compute dt
        dt = self.dt_proj(x)  # [batch_size, seq_len, 1]
        dt = torch.sigmoid(dt) * (self.dt_max - self.dt_min) + self.dt_min
        print(f"dt.shape after projection: {dt.shape}")

        # Debug the time_intervals handling
        if time_intervals is not None:
            batch_size, current_len = time_intervals.shape
            expected_len = dt.shape[1]
            print(f"time_intervals: batch_size={batch_size}, current_len={current_len}")
            print(f"dt: expected_len={expected_len}")

            if current_len < expected_len:
                padding = time_intervals[:, -1:].expand(batch_size, expected_len - current_len)
                time_intervals_padded = torch.cat([time_intervals, padding], dim=1)
                print(f"After padding: {time_intervals_padded.shape}")
            else:
                time_intervals_padded = time_intervals
                print(f"No padding needed: {time_intervals_padded.shape}")

            print(f"Before multiplication:")
            print(f"  dt.squeeze(-1).shape: {dt.squeeze(-1).shape}")
            print(f"  time_intervals_padded.unsqueeze(-1).shape: {time_intervals_padded.unsqueeze(-1).shape}")

            # This is the line that's failing
            dt = dt.squeeze(-1) * time_intervals_padded.unsqueeze(-1)
            print(f"After multiplication: dt.shape = {dt.shape}")

        return x  # Just return x for debugging

# Test with debug layer
debug_s5 = DebugS5Layer(d_model=64, d_state=16)
try:
    output = debug_s5(x, time_intervals)
    print("✅ Debug layer succeeded!")
except Exception as e:
    print(f"❌ Debug layer failed: {e}")

Input x.shape: torch.Size([1024, 200, 64])
Input time_intervals.shape: torch.Size([1024, 199])
dt.shape after projection: torch.Size([1024, 200, 1])
time_intervals: batch_size=1024, current_len=199
dt: expected_len=200
After padding: torch.Size([1024, 200])
Before multiplication:
  dt.squeeze(-1).shape: torch.Size([1024, 200])
  time_intervals_padded.unsqueeze(-1).shape: torch.Size([1024, 200, 1])
❌ Debug layer failed: The size of tensor a (1024) must match the size of tensor b (200) at non-singleton dimension 1


In [8]:
# Let's examine the exact tensors and their properties
dt_squeezed = dt.squeeze(-1)
time_padded_unsqueezed = time_intervals_padded.unsqueeze(-1)

print("=== TENSOR ANALYSIS ===")
print(f"dt_squeezed.shape: {dt_squeezed.shape}")
print(f"dt_squeezed.stride(): {dt_squeezed.stride()}")
print(f"dt_squeezed.is_contiguous(): {dt_squeezed.is_contiguous()}")

print(f"\ntime_padded_unsqueezed.shape: {time_padded_unsqueezed.shape}")
print(f"time_padded_unsqueezed.stride(): {time_padded_unsqueezed.stride()}")
print(f"time_padded_unsqueezed.is_contiguous(): {time_padded_unsqueezed.is_contiguous()}")

print("\n=== ATTEMPTING DIFFERENT APPROACHES ===")

# Approach 1: Make tensors contiguous
try:
    result1 = dt_squeezed.contiguous() * time_padded_unsqueezed.contiguous()
    print("✅ Approach 1 (contiguous) succeeded!")
except Exception as e:
    print(f"❌ Approach 1 failed: {e}")

# Approach 2: Different dimension handling
try:
    # Instead of squeeze/unsqueeze, let's try different reshaping
    dt_reshaped = dt.view(1024, 200)  # [batch, seq_len]
    time_reshaped = time_intervals_padded  # Already [batch, seq_len]
    result2 = dt_reshaped * time_reshaped
    print("✅ Approach 2 (reshape) succeeded!")
    print(f"Result2 shape: {result2.shape}")
except Exception as e:
    print(f"❌ Approach 2 failed: {e}")

# Approach 3: Element-wise with explicit broadcasting
try:
    dt_for_broadcast = dt.squeeze(-1)  # [1024, 200]
    time_for_broadcast = time_intervals_padded  # [1024, 200]
    print(f"Broadcasting shapes: {dt_for_broadcast.shape} * {time_for_broadcast.shape}")
    result3 = torch.mul(dt_for_broadcast, time_for_broadcast)
    print("✅ Approach 3 (torch.mul) succeeded!")
    print(f"Result3 shape: {result3.shape}")
except Exception as e:
    print(f"❌ Approach 3 failed: {e}")

=== TENSOR ANALYSIS ===
dt_squeezed.shape: torch.Size([1024, 200])
dt_squeezed.stride(): (200, 1)
dt_squeezed.is_contiguous(): True

time_padded_unsqueezed.shape: torch.Size([1024, 200, 1])
time_padded_unsqueezed.stride(): (200, 1, 1)
time_padded_unsqueezed.is_contiguous(): True

=== ATTEMPTING DIFFERENT APPROACHES ===
❌ Approach 1 failed: The size of tensor a (1024) must match the size of tensor b (200) at non-singleton dimension 1
✅ Approach 2 (reshape) succeeded!
Result2 shape: torch.Size([1024, 200])
Broadcasting shapes: torch.Size([1024, 200]) * torch.Size([1024, 200])
✅ Approach 3 (torch.mul) succeeded!
Result3 shape: torch.Size([1024, 200])


In [9]:
# Test the fixed approach in the S5Layer
class FixedS5Layer(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.dt_proj = nn.Linear(d_model, 1)
        self.dt_min = 1e-3
        self.dt_max = 1e-1
        self.A_log = nn.Parameter(torch.randn(d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_state, d_model))

    def forward(self, x, time_intervals=None):
        dt = self.dt_proj(x)
        dt = torch.sigmoid(dt) * (self.dt_max - self.dt_min) + self.dt_min

        if time_intervals is not None:
            batch_size, current_len = time_intervals.shape
            expected_len = dt.shape[1]

            if current_len < expected_len:
                padding = time_intervals[:, -1:].expand(batch_size, expected_len - current_len)
                time_intervals_padded = torch.cat([time_intervals, padding], dim=1)
            else:
                time_intervals_padded = time_intervals

            # FIXED LINE: Use view instead of squeeze/unsqueeze
            dt = dt.view(batch_size, expected_len) * time_intervals_padded
            dt = dt.unsqueeze(-1)  # Add back the last dimension for later operations

        return x  # Return input for testing

# Test the fix
fixed_s5 = FixedS5Layer(d_model=64, d_state=16)
try:
    output = fixed_s5(x, time_intervals)
    print("✅ Fixed S5Layer succeeded!")
except Exception as e:
    print(f"❌ Fixed S5Layer failed: {e}")

✅ Fixed S5Layer succeeded!


In [10]:
# Test the new approach: dt * time_intervals_padded.unsqueeze(-1)
batch_size = 1024
seq_len = 200
d_model = 64

# Create the same tensors as before
x = torch.randn(batch_size, seq_len, d_model)
time_intervals = torch.randn(batch_size, seq_len - 1)  # [1024, 199]

# Simulate dt computation
dt = torch.randn(batch_size, seq_len, 1)  # [batch_size, seq_len, 1] - original shape

print(f"dt.shape: {dt.shape}")
print(f"time_intervals.shape: {time_intervals.shape}")

# Do the padding
batch_size, current_len = time_intervals.shape
expected_len = dt.shape[1]
padding = time_intervals[:, -1:].expand(batch_size, expected_len - current_len)
time_intervals_padded = torch.cat([time_intervals, padding], dim=1)

print(f"time_intervals_padded.shape: {time_intervals_padded.shape}")
print(f"time_intervals_padded.unsqueeze(-1).shape: {time_intervals_padded.unsqueeze(-1).shape}")

# Test the new multiplication approach
try:
    result = dt * time_intervals_padded.unsqueeze(-1)
    print(f"✅ New approach succeeded! Result shape: {result.shape}")
except Exception as e:
    print(f"❌ New approach failed: {e}")

dt.shape: torch.Size([1024, 200, 1])
time_intervals.shape: torch.Size([1024, 199])
time_intervals_padded.shape: torch.Size([1024, 200])
time_intervals_padded.unsqueeze(-1).shape: torch.Size([1024, 200, 1])
✅ New approach succeeded! Result shape: torch.Size([1024, 200, 1])


In [11]:
class FinalS5Layer(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.dt_proj = nn.Linear(d_model, 1)
        self.dt_min = 1e-3
        self.dt_max = 1e-1
        self.A_log = nn.Parameter(torch.randn(d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_state, d_model))

    def forward(self, x, time_intervals=None):
        batch_size, seq_len, _ = x.shape
        dt = self.dt_proj(x)
        dt = torch.sigmoid(dt) * (self.dt_max - self.dt_min) + self.dt_min

        if time_intervals is not None:
            batch_size, current_len = time_intervals.shape
            expected_len = dt.shape[1]

            if current_len < expected_len:
                padding = time_intervals[:, -1:].expand(batch_size, expected_len - current_len)
                time_intervals_padded = torch.cat([time_intervals, padding], dim=1)
            else:
                time_intervals_padded = time_intervals

            # FINAL FIX: Proper broadcasting maintaining tensor dimensions
            dt = dt * time_intervals_padded.unsqueeze(-1)

        # Test the downstream operations that were failing
        A = -torch.exp(self.A_log)
        print(f"dt.shape before dA/dB: {dt.shape}")
        dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
        dB = (dt.unsqueeze(-1) * self.B.unsqueeze(0).unsqueeze(0))
        print(f"dA.shape: {dA.shape}")
        print(f"dB.shape: {dB.shape}")

        return x

# Test the final fix
final_s5 = FinalS5Layer(d_model=64, d_state=16)
try:
    output = final_s5(x, time_intervals)
    print("✅ Final S5Layer with all operations succeeded!")
except Exception as e:
    print(f"❌ Final S5Layer failed: {e}")

dt.shape before dA/dB: torch.Size([1024, 200, 1])
dA.shape: torch.Size([1024, 200, 1, 16])
dB.shape: torch.Size([1024, 200, 16, 64])
✅ Final S5Layer with all operations succeeded!


In [12]:
# Debug the states computation that's failing
batch_size = 1024
seq_len = 200
d_model = 64
d_state = 16

# Simulate the tensors that should be created
dt = torch.randn(batch_size, seq_len, 1)  # After our fix
A = torch.randn(d_state)
B = torch.randn(d_state, d_model)
C = torch.randn(d_state, d_model)

print("=== TENSOR SHAPES ===")
print(f"dt.shape: {dt.shape}")
print(f"A.shape: {A.shape}")
print(f"B.shape: {B.shape}")
print(f"C.shape: {C.shape}")

# Compute dA and dB like in the original code
dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))  # [batch_size, seq_len, d_state]
dB = (dt.unsqueeze(-1) * B.unsqueeze(0).unsqueeze(0))  # [batch_size, seq_len, d_state, d_model]

print(f"dA.shape: {dA.shape}")
print(f"dB.shape: {dB.shape}")

# Initialize states
states = torch.zeros(batch_size, d_state)
print(f"Initial states.shape: {states.shape}")

# Simulate the first iteration of the loop
t = 0
x = torch.randn(batch_size, seq_len, d_model)  # Input tensor
u_t = x[:, t, :]  # [batch_size, d_model]
print(f"u_t.shape: {u_t.shape}")

# This is the problematic line 125
try:
    new_states = dA[:, t, :] * states + torch.einsum('bsd,bd->bs', dB[:, t, :, :], u_t)
    print(f"✅ States update succeeded: {new_states.shape}")
    states = new_states
except Exception as e:
    print(f"❌ States update failed: {e}")

# This is the problematic line 128
try:
    y_t = torch.einsum('ds,bs->bd', C, states)
    print(f"✅ Output computation succeeded: {y_t.shape}")
except Exception as e:
    print(f"❌ Output computation failed: {e}")
    print(f"Expected: C.shape={C.shape}, states.shape={states.shape}")
    print(f"C dimensions: {len(C.shape)}, states dimensions: {len(states.shape)}")

=== TENSOR SHAPES ===
dt.shape: torch.Size([1024, 200, 1])
A.shape: torch.Size([16])
B.shape: torch.Size([16, 64])
C.shape: torch.Size([16, 64])
dA.shape: torch.Size([1024, 200, 1, 16])
dB.shape: torch.Size([1024, 200, 16, 64])
Initial states.shape: torch.Size([1024, 16])
u_t.shape: torch.Size([1024, 64])
✅ States update succeeded: torch.Size([1024, 1024, 16])
❌ Output computation failed: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (3) for operand 1 and no ellipsis was given
Expected: C.shape=torch.Size([16, 64]), states.shape=torch.Size([1024, 1024, 16])
C dimensions: 2, states dimensions: 3


In [13]:
# Debug the problematic multiplication
print("=== DEBUGGING dA MULTIPLICATION ===")
print(f"dA.shape: {dA.shape}")  # [1024, 200, 1, 16]
print(f"dA[:, t, :].shape: {dA[:, t, :].shape}")  # This is the problem!

# The issue: dA[:, t, :] should be [1024, 16], but it's [1024, 1, 16]
dA_t = dA[:, t, :]  # [1024, 1, 16] 
states = torch.zeros(batch_size, d_state)  # [1024, 16]

print(f"dA_t.shape: {dA_t.shape}")
print(f"states.shape: {states.shape}")

# This multiplication creates wrong shape due to broadcasting
result = dA_t * states
print(f"dA_t * states result shape: {result.shape}")  # This will be wrong!

# The fix: we need to squeeze the extra dimension
print("\n=== TESTING THE FIX ===")
dA_t_fixed = dA[:, t, :].squeeze(1)  # Remove the singleton dimension
print(f"dA_t_fixed.shape: {dA_t_fixed.shape}")

result_fixed = dA_t_fixed * states
print(f"dA_t_fixed * states shape: {result_fixed.shape}")

# Test the full computation with the fix
einsum_result = torch.einsum('bsd,bd->bs', dB[:, t, :, :], u_t)
print(f"einsum result shape: {einsum_result.shape}")

final_states = dA_t_fixed * states + einsum_result
print(f"Final states shape: {final_states.shape}")

=== DEBUGGING dA MULTIPLICATION ===
dA.shape: torch.Size([1024, 200, 1, 16])
dA[:, t, :].shape: torch.Size([1024, 1, 16])
dA_t.shape: torch.Size([1024, 1, 16])
states.shape: torch.Size([1024, 16])
dA_t * states result shape: torch.Size([1024, 1024, 16])

=== TESTING THE FIX ===
dA_t_fixed.shape: torch.Size([1024, 16])
dA_t_fixed * states shape: torch.Size([1024, 16])
einsum result shape: torch.Size([1024, 16])
Final states shape: torch.Size([1024, 16])


In [14]:
class CompleteS5Layer(nn.Module):
    def __init__(self, d_model, d_state):
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.dt_proj = nn.Linear(d_model, 1)
        self.dt_min = 1e-3
        self.dt_max = 1e-1
        self.A_log = nn.Parameter(torch.randn(d_state))
        self.B = nn.Parameter(torch.randn(d_state, d_model))
        self.C = nn.Parameter(torch.randn(d_state, d_model))
        self.D = nn.Parameter(torch.zeros(d_model))  # Add D parameter
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, time_intervals=None):
        batch_size, seq_len, d_model = x.shape

        # Compute dt
        dt = self.dt_proj(x)
        dt = torch.sigmoid(dt) * (self.dt_max - self.dt_min) + self.dt_min

        # Apply time intervals
        if time_intervals is not None:
            batch_size, current_len = time_intervals.shape
            expected_len = dt.shape[1]

            if current_len < expected_len:
                padding = time_intervals[:, -1:].expand(batch_size, expected_len - current_len)
                time_intervals_padded = torch.cat([time_intervals, padding], dim=1)
            else:
                time_intervals_padded = time_intervals

            dt = dt * time_intervals_padded.unsqueeze(-1)

        # State space matrices
        A = -torch.exp(self.A_log)
        dA = torch.exp(dt.unsqueeze(-1) * A.unsqueeze(0).unsqueeze(0))
        dB = (dt.unsqueeze(-1) * self.B.unsqueeze(0).unsqueeze(0))

        # State space computation
        states = torch.zeros(batch_size, self.d_state, device=x.device, dtype=x.dtype)
        outputs = []

        for t in range(seq_len):
            u_t = x[:, t, :]
            # FIXED: squeeze the singleton dimension
            states = dA[:, t, :].squeeze(1) * states + torch.einsum('bsd,bd->bs', dB[:, t, :, :], u_t)
            y_t = torch.einsum('ds,bs->bd', self.C, states) + self.D * u_t
            outputs.append(y_t)

        output = torch.stack(outputs, dim=1)
        return self.norm(output)

# Test with actual tensor shapes
complete_s5 = CompleteS5Layer(d_model=64, d_state=16)
try:
    output = complete_s5(x, time_intervals)
    print(f"✅ Complete S5Layer succeeded! Output shape: {output.shape}")
except Exception as e:
    print(f"❌ Complete S5Layer failed: {e}")

❌ Complete S5Layer failed: einsum(): subscript s has size 16 for operand 1 which does not broadcast with previously seen size 64


In [15]:
# Check what RecBole has for SS4Rec
try:
    from recbole.model.sequential_recommender import SASRec
    from recbole.config import Config
    from recbole.data import create_dataset, data_preparation
    print("✅ RecBole is available!")
    print("Available sequential models in RecBole:")

    # List some common sequential models in RecBole
    models = ['SASRec', 'BERT4Rec', 'GRU4Rec', 'NextItNet']
    for model in models:
        try:
            exec(f"from recbole.model.sequential_recommender import {model}")
            print(f"  ✅ {model}")
        except:
            print(f"  ❌ {model}")

except ImportError as e:
    print(f"❌ RecBole import failed: {e}")

❌ RecBole import failed: No module named 'recbole'
