In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
from tqdm import tqdm


def reshape_to_sequence(x):
    return x.reshape(3, 1024).permute(1, 0)  # (1024, 3)

def load_sequential_cifar10(batch_size=128, fraction=0.2):
    """Load CIFAR-10 in sequential format with optional subset"""
    # Transform to convert images to sequences
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(reshape_to_sequence)
    ])

    # Download and load datasets
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform)

    if fraction < 1.0:
        train_size = len(trainset)
        subset_size = int(train_size * fraction)
        indices = torch.randperm(train_size)[:subset_size]
        trainset = torch.utils.data.Subset(trainset, indices)

    # Create dataloaders
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    return trainloader, testloader, 3  # 3 channels


class S4Layer(nn.Module):
    def __init__(self, d_model, N=64, l_max=1024, discretization='zoh'):
        """
        S4 layer implementation with support for different discretization schemes

        Args:
            d_model: Model dimension
            N: State dimension
            l_max: Maximum sequence length
            discretization: Discretization scheme ('zoh', 'bilinear', or 'generalized')
        """
        super().__init__()
        self.d_model = d_model
        self.N = N
        self.l_max = l_max
        self.discretization = discretization
        self.register_buffer('dummy', torch.zeros(1))

        # Initialize SSM parameters
        # A is the state matrix (N x N)
        A = self._init_hippo_matrix()
        self.register_buffer('A', A)

        # Learnable parameters
        # B is the input matrix (N x 1)
        self.B = nn.Parameter(torch.randn(self.N, 1))
        # C is the output matrix (1 x N)
        self.C = nn.Parameter(torch.randn(1, self.N))
        # D is the feedthrough term
        self.D = nn.Parameter(torch.zeros(1))

        # Timescale parameter (delta)
        self.log_step = nn.Parameter(torch.zeros(1))

        # Register buffers for discrete matrices
        self.register_buffer('A_disc', torch.zeros(self.N, self.N))
        self.register_buffer('B_disc', torch.zeros(self.N, 1))
        self.register_buffer('k', torch.zeros(self.l_max))

        # Initialize discrete-time matrices
        self._setup_discretization()

    def _init_hippo_matrix(self):
        """Initialize A using HiPPO Normal matrix"""
        device = self.dummy.device
        A = np.zeros((self.N, self.N))
        for i in range(self.N):
            for j in range(self.N):
                if i < j:
                    A[i, j] = 0
                elif i == j:
                    A[i, j] = -(2 * i + 1)
                else:
                    A[i, j] = -2 * np.sqrt((2 * i + 1) * (2 * j + 1))
        return torch.from_numpy(A).float().to(device)

    def _setup_discretization(self):
        """Set up discrete-time matrices based on the selected scheme"""
        # Get step size
        dt = torch.exp(self.log_step)
        device = self.dummy.device

        # Create identity matrix on the correct device
        I = torch.eye(self.N, device=device)

        if self.discretization == 'zoh':
            # Zero-order hold discretization (standard)
            self.A_disc = torch.matrix_exp(self.A * dt)
            self.B_disc = torch.linalg.solve(
                self.A,
                (self.A_disc - I).matmul(self.B)
            )

        elif self.discretization == 'bilinear':
            # Bilinear transform (Tustin's method)
            left = torch.linalg.inv(I - dt/2 * self.A)
            right = I + dt/2 * self.A
            self.A_disc = left @ right
            self.B_disc = left @ (dt * self.B)

        elif self.discretization == 'generalized':
            # Generalized bilinear transform
            alpha = 0.5  # Can be made learnable
            left = torch.linalg.inv(I - dt * (1-alpha) * self.A)
            right = I + dt * alpha * self.A
            self.A_disc = left @ right
            self.B_disc = left @ (dt * self.B)

        # Pre-compute kernel
        self._compute_kernel()

    def _compute_kernel(self):
        """Compute convolution kernel for fast inference"""
        L = self.l_max
        device = self.dummy.device

        # Create tensor on the correct device
        k = torch.zeros(L, dtype=torch.cfloat, device=device)

        # Direct computation method
        A_powers = self.A_disc.unsqueeze(0)  # [1, N, N]
        for i in range(L):
            if i > 0:
                A_powers = A_powers @ self.A_disc
            k[i] = (self.C @ A_powers.squeeze(0) @ self.B_disc).item()

        # Store as real kernel for simplicity (imaginary part should be small)
        self.k = torch.real(k)

    def forward(self, x):
        """
        Forward pass using convolution for efficiency
        x: [batch_size, seq_len, d_model]
        """
        batch, seq_len, _ = x.shape
        device = x.device

        if self.dummy.device != device:
            self.to(device)

        # Recalculate discretization if parameters changed
        if self.training:
            self._setup_discretization()

        # Process each feature dimension separately
        out = torch.zeros_like(x)

        # Causal convolution for each feature
        for i in range(self.d_model):
            # Extract feature i
            xi = x[:, :, i].unsqueeze(1)  # [batch, 1, seq_len]

            # Prepare kernel - ensure it's on the correct device
            k_padded = nn.functional.pad(
                self.k[:seq_len].flip(0).unsqueeze(0).unsqueeze(0),
                (0, seq_len-1)
            )

            # Causal convolution
            yi = nn.functional.conv1d(xi, k_padded, padding=seq_len-1)
            yi = yi[:, :, :seq_len]

            # Add D term (skip connection) and store result
            out[:, :, i] = yi.squeeze(1) + self.D * x[:, :, i]

        return out


class S4Block(nn.Module):
    """S4 block with normalization, activation, and residual connection"""
    def __init__(self, d_model, N=64, l_max=1024, discretization='zoh'):
        super().__init__()
        self.s4 = S4Layer(d_model, N, l_max, discretization)
        self.norm = nn.LayerNorm(d_model)
        self.activation = nn.GELU()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # Pre-normalization
        z = self.norm(x)

        # Apply S4 layer
        z = self.s4(z)

        # Activation and dropout
        z = self.activation(z)
        z = self.dropout(z)

        # Residual connection
        return x + z


class S4Model(nn.Module):
    """Complete S4 model for classification"""
    def __init__(self, d_input, d_model=64, n_layers=2, d_output=10, discretization='zoh'):
        super().__init__()
        self.encoder = nn.Linear(d_input, d_model)

        # Stack S4 blocks
        self.blocks = nn.ModuleList([
            S4Block(d_model, discretization=discretization)
            for _ in range(n_layers)
        ])

        # Classification head
        self.decoder = nn.Linear(d_model, d_output)

    def forward(self, x):
        # Encode input
        x = self.encoder(x)  # [batch, seq_len, d_model]

        # Apply S4 blocks
        for block in self.blocks:
            x = block(x)

        # Global average pooling
        x = x.mean(dim=1)  # [batch, d_model]

        # Final classification
        return self.decoder(x)  # [batch, d_output]


class EarlyStopping:
    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = float('inf')

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return False
        else:
            self.counter += 1
            if self.counter >= self.patience:
                return True
            return False


def train_epoch(model, trainloader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    return total_loss / len(trainloader), 100. * correct / total


def evaluate(model, testloader, criterion, device):
    """Evaluate model on test set"""
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    return total_loss / len(testloader), 100. * correct / total


def main():
    # Parameters - reduced for faster training
    batch_size = 64
    d_model = 64  
    n_layers = 4 
    epochs = 5   
    fraction = 0.5  
    discretization_schemes = ['zoh', 'bilinear', 'generalized']

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load data with subset for faster training
    trainloader, testloader, d_input = load_sequential_cifar10(batch_size, fraction)
    print(f"Data loaded: sequence length = 1024, d_input = {d_input}, using {fraction*100}% of training data")

    # Model architecture summary
    print(f"Model architecture summary:")
    dummy_input = torch.zeros(1, 1024, d_input).to(device)
    model = S4Model(d_input, d_model, n_layers, 10, 'zoh').to(device)
    print(f"Total parameters: {sum(p.numel() for p in model.parameters())}")

    # Experiment with different discretization schemes
    results = {}

    for scheme in discretization_schemes:
        print(f"\n{'='*50}")
        print(f"Training with {scheme} discretization")
        print(f"{'='*50}")

    
        model = S4Model(
            d_input=d_input,
            d_model=d_model,
            n_layers=n_layers,
            d_output=10,
            discretization=scheme
        ).to(device)

        # Setup optimizer with different learning rates
        # S4 parameters need smaller learning rate with no weight decay
        s4_params = []
        other_params = []
        for name, param in model.named_parameters():
            if 's4.' in name and any(p in name for p in ['B', 'C', 'D', 'log_step']):
                s4_params.append(param)
            else:
                other_params.append(param)

        optimizer = optim.AdamW([
            {'params': s4_params, 'lr': 0.001, 'weight_decay': 0.0},
            {'params': other_params, 'lr': 0.004, 'weight_decay': 0.01}
        ])
        criterion = nn.CrossEntropyLoss()

        # Early stopping
        early_stopping = EarlyStopping(patience=3)

        # Training loop
        best_acc = 0
        for epoch in range(epochs):
            # Train
            train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)

            # Evaluate
            test_loss, test_acc = evaluate(model, testloader, criterion, device)

            print(f"Epoch {epoch+1}/{epochs}: "
                  f"Train Loss={train_loss:.4f}, Train Acc={train_acc:.2f}%, "
                  f"Test Loss={test_loss:.4f}, Test Acc={test_acc:.2f}%")

            if test_acc > best_acc:
                best_acc = test_acc

            # Check for early stopping
            if early_stopping(test_loss):
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

        # Store results
        results[scheme] = best_acc
        print(f"Best accuracy with {scheme}: {best_acc:.2f}%")

    # Compare results
    print("\n--- Final Results ---")
    for scheme, acc in results.items():
        print(f"{scheme}: {acc:.2f}%")


if __name__ == "__main__":
    main()


Using device: cuda
Files already downloaded and verified
Files already downloaded and verified
Data loaded: sequence length = 1024, d_input = 3, using 50.0% of training data
Model architecture summary:
Total parameters: 1938

Training with zoh discretization


100%|██████████| 391/391 [12:16<00:00,  1.88s/it]


Epoch 1/5: Train Loss=2.0167, Train Acc=24.60%, Test Loss=1.8611, Test Acc=32.39%


100%|██████████| 391/391 [12:15<00:00,  1.88s/it]


Epoch 2/5: Train Loss=1.8784, Train Acc=30.41%, Test Loss=1.8347, Test Acc=32.47%


100%|██████████| 391/391 [12:14<00:00,  1.88s/it]


Epoch 3/5: Train Loss=1.8486, Train Acc=31.46%, Test Loss=1.8301, Test Acc=32.90%


100%|██████████| 391/391 [12:15<00:00,  1.88s/it]


Epoch 4/5: Train Loss=1.8236, Train Acc=32.82%, Test Loss=1.8126, Test Acc=33.88%


100%|██████████| 391/391 [12:13<00:00,  1.88s/it]


Epoch 5/5: Train Loss=1.8010, Train Acc=33.67%, Test Loss=1.7835, Test Acc=34.71%
Best accuracy with zoh: 34.71%

Training with bilinear discretization


100%|██████████| 391/391 [12:11<00:00,  1.87s/it]


Epoch 1/5: Train Loss=2.2791, Train Acc=19.75%, Test Loss=2.1944, Test Acc=22.70%


100%|██████████| 391/391 [12:09<00:00,  1.87s/it]


Epoch 2/5: Train Loss=2.0772, Train Acc=23.68%, Test Loss=2.2016, Test Acc=21.97%


100%|██████████| 391/391 [12:11<00:00,  1.87s/it]


Epoch 3/5: Train Loss=2.0072, Train Acc=25.54%, Test Loss=2.1096, Test Acc=24.57%


100%|██████████| 391/391 [12:14<00:00,  1.88s/it]


Epoch 4/5: Train Loss=1.9611, Train Acc=27.27%, Test Loss=2.0382, Test Acc=24.69%


100%|██████████| 391/391 [12:12<00:00,  1.87s/it]


Epoch 5/5: Train Loss=1.9210, Train Acc=29.00%, Test Loss=2.1967, Test Acc=20.36%
Best accuracy with bilinear: 24.69%

Training with generalized discretization


100%|██████████| 391/391 [12:11<00:00,  1.87s/it]


Epoch 1/5: Train Loss=2.2106, Train Acc=20.75%, Test Loss=2.0100, Test Acc=25.39%


100%|██████████| 391/391 [12:15<00:00,  1.88s/it]


Epoch 2/5: Train Loss=2.0007, Train Acc=25.75%, Test Loss=1.9443, Test Acc=28.44%


100%|██████████| 391/391 [12:11<00:00,  1.87s/it]


Epoch 3/5: Train Loss=1.9432, Train Acc=28.38%, Test Loss=1.9636, Test Acc=30.28%


100%|██████████| 391/391 [12:10<00:00,  1.87s/it]


Epoch 4/5: Train Loss=1.9202, Train Acc=29.12%, Test Loss=2.0019, Test Acc=27.08%


100%|██████████| 391/391 [12:10<00:00,  1.87s/it]


Epoch 5/5: Train Loss=1.9000, Train Acc=29.72%, Test Loss=1.8942, Test Acc=32.43%
Best accuracy with generalized: 32.43%

--- Final Results ---
zoh: 34.71%
bilinear: 24.69%
generalized: 32.43%
