## State Space Models 
# Major Takeaways:
  - This notebook demonstrates a simple State Space Model (SSM) for time series modeling.
    Unlike RNNs or LSTMs, which unroll sequentially through time, the SSM computes its output for all time steps at once with a causal 1D convolution, which makes it more computationally efficient.
  - The implemented SSM is a minimal linear variant with a diagonal state matrix, so it serves as a baseline model. We expect it to have lower modeling capacity compared to nonlinear RNNs or LSTMs.
  - We will train both the SSM and an RNN model on the same dataset and compare their performance, learning dynamics, and computational efficiency.

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_auc_score
from utils import get_validation_score
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt

In [None]:
class SSMLayer(nn.Module):
    """
    Linear State Space Model (SSM) layer with a diagonal state matrix,
    evaluated as a causal convolution instead of a step-by-step recurrence.

    The layer is equivalent to running, from a zero initial state,
        h_t = a * h_{t-1} + B x_t
        y_t = C h_t + D x_t
    where `a` is the diagonal of the state matrix. Its impulse response is
        K[k] = C diag(a**k) B      (plus D at lag k = 0)
    and the output is the causal convolution of the input with K.

    Params:
      - n: hidden/state dimension (length of the diagonal `a`)
      - d_in: input dimension (per timestep)
      - d_out: output dimension (per timestep)
      - max_len: stored sequence-length bound; it is bookkeeping only and is
        raised automatically when a longer sequence is seen

    Input to forward: x with shape (batch, d_in, T)
    Output: y with shape (batch, d_out, T)
    """
    def __init__(self, n, d_in, d_out, max_len=512):
        super().__init__()
        self.n, self.d_in, self.d_out = n, d_in, d_out
        self.max_len = max_len

        # Unconstrained parameter; _get_a() squashes it into (-1, 1) so the
        # recurrence stays stable.
        self.logit_a = nn.Parameter(torch.full((n,), -3.0))
        self.B = nn.Parameter(0.1 * torch.randn(n, d_in))      # (n, d_in)
        self.C = nn.Parameter(0.1 * torch.randn(d_out, n))     # (d_out, n)
        self.D = nn.Parameter(torch.zeros(d_out, d_in))        # (d_out, d_in)

    def _get_a(self):
        """Return the diagonal state entries, kept strictly inside (-1, 1)."""
        # The 0.999 factor keeps the values away from exactly ±1.
        return torch.tanh(self.logit_a) * 0.999

    def compute_kernel(self, T):
        """Return the length-T impulse response K with shape (d_out, d_in, T)."""
        self.max_len = max(self.max_len, T)

        decay = self._get_a()                                             # (n,)
        lags = torch.arange(T, dtype=decay.dtype, device=decay.device)    # (T,)
        decay_pow = decay[None, :].pow(lags[:, None])                     # (T, n): a_j ** k
        scaled_B = decay_pow[:, :, None] * self.B[None]                   # (T, n, d_in)
        taps = torch.einsum('on,tni->toi', self.C, scaled_B)              # (T, d_out, d_in)

        # The direct feed-through D only contributes at lag 0.
        lag_zero = (taps[0] + self.D).unsqueeze(0)                        # (1, d_out, d_in)
        taps = torch.cat((lag_zero, taps[1:]), dim=0)                     # (T, d_out, d_in)
        return taps.permute(1, 2, 0).contiguous()                         # (d_out, d_in, T)

    def forward(self, x):
        """
        x: (batch, d_in, T)
        returns y: (batch, d_out, T)
        """
        seq_len = x.shape[-1]
        kernel = self.compute_kernel(seq_len)          # (d_out, d_in, T)
        # conv1d computes a cross-correlation, so flip the kernel along time
        # to obtain a true convolution.
        flipped = kernel.flip(-1)
        # Left-only zero padding makes the convolution causal and keeps the
        # output length equal to T.
        left_context = flipped.shape[-1] - 1
        return F.conv1d(F.pad(x, (left_context, 0)), flipped)   # (batch, d_out, T)



In [None]:
class SSMClassifier(nn.Module):
    """
    Binary sequence classifier built on a single SSMLayer.

    It exposes the same interface as RNNClassifier so the two can be swapped:
      - constructor signature: (input_dim, hidden_dim, device)
      - forward(x) expects (batch, time_steps, features)
      - returns sigmoid output of shape (batch, 1)
    """
    def __init__(self, input_dim, hidden_dim, device):
        super().__init__()
        self.hidden_dim = hidden_dim
        # SSMLayer takes (n, d_in, d_out). Using hidden_dim for both the state
        # size and the output size yields one hidden_dim vector per timestep.
        self.ssm = SSMLayer(hidden_dim, input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, 1)
        self.activation = nn.Sigmoid()
        self.device = device
        self.to(device)

    def forward(self, x):
        """
        x: (batch, time_steps, features)
        returns: (batch, 1) with sigmoid activation
        """
        # Harmless if the caller has already placed x on the right device.
        on_device = x.to(self.device)

        # The SSM layer works on channels-first input: (B, d_in, T).
        channels_first = on_device.permute(0, 2, 1).contiguous()
        per_step = self.ssm(channels_first)            # (B, hidden_dim, T)

        # Summarise the sequence by its final timestep, analogous to the last
        # hidden state used by RNNClassifier. Mean pooling over time
        # (per_step.mean(dim=2)) would be an alternative.
        final_step = per_step[:, :, -1]                # (B, hidden_dim)

        logits = self.fc(final_step)                   # (B, 1)
        return self.activation(logits)

In [None]:
### For comparison

class RNNClassifier(nn.Module):
    """
    Binary sequence classifier that unrolls an nn.RNNCell over time.

    forward(x) expects x of shape (batch, time_steps, features), already on
    `device`, and returns sigmoid probabilities of shape (batch, 1) computed
    from the hidden state after the last time step.
    """

    def __init__(self, input_dim, hidden_dim,device):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.rnn = nn.RNNCell(input_dim, hidden_dim)   # single-step recurrent cell
        self.fc = nn.Linear(hidden_dim, 1)             # last hidden state -> one logit
        self.activation = nn.Sigmoid()                 # binary classification
        self.device=device

    def forward(self, x):
        state = self.init_hidden(x)

        # Walk the time axis (dim 1); every slice has shape (batch, features).
        for step_input in x.unbind(dim=1):
            state = self.rnn(step_input, state)

        # Only the state after the final time step feeds the classifier head.
        logits = self.fc(state)
        return self.activation(logits)

    def init_hidden(self, x):
        """Return an all-zero initial hidden state of shape (batch, hidden_dim)."""
        batch_size = x.size(0)
        return torch.zeros(batch_size, self.hidden_dim, device=self.device)

In [None]:
# Function to compute validation score: Used in trainer()

def get_validation_score(model,Val_T,Val_L):
    """
    Evaluate `model` on the full validation set.

    Params:
      - model: classifier exposing a `.device` attribute; model(x)[:, 0] must
        be the predicted probability per sample
      - Val_T: validation inputs (array-like accepted by torch.as_tensor)
      - Val_L: validation labels (array-like of 0/1 values)

    Returns: (roc_auc, val_loss) where roc_auc is the float returned by
    roc_auc_score and val_loss is a 0-dim tensor holding the BCE loss.

    Note: the model is left in eval mode; callers that continue training must
    call model.train() again.
    """
    model.eval()
    # BCELoss has no parameters or buffers, so it needs no device placement.
    # This also removes the dependency on a notebook-global `device`.
    criterion = nn.BCELoss()

    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph over the whole validation set.
    with torch.no_grad():
        inputs = torch.as_tensor(Val_T, dtype=torch.float32).to(model.device)
        targets = torch.as_tensor(Val_L, dtype=torch.float32).to(model.device)
        preds = model(inputs)[:,0]
        val_loss = criterion(preds, targets)

    return roc_auc_score(Val_L, preds.cpu().numpy()), val_loss

In [None]:
def trainer(model,training_set,validation_set,device,lr,stored_name,epochs=10):
    """
    Train `model` with SGD (Nesterov momentum) and BCE loss, checkpointing the
    model with the best validation ROC-AUC.

    Params:
      - model: classifier whose output column 0 is the predicted probability
      - training_set: (T, L) sequences of input batches and label batches
      - validation_set: (Val_T, Val_L) passed to get_validation_score
      - device: device the batches are moved to
      - lr: learning rate
      - stored_name: checkpoint file name, written to the current directory
      - epochs: number of passes over the training batches

    Returns: (best_model, LOSS, VAL_LOSS) where best_model is reloaded from
    the checkpoint, and LOSS / VAL_LOSS hold the per-epoch mean training loss
    and validation loss as 0-dim numpy arrays.
    """
    T,L=training_set
    Val_T,Val_L=validation_set

    optimizer_model = torch.optim.SGD(model.parameters(),lr,momentum=0.9, nesterov=True)
    criterion = nn.BCELoss().to(device)

    checkpoint_path = './'+stored_name
    # -inf guarantees that the first completed epoch always writes a checkpoint.
    best = float('-inf')
    LOSS=[]
    VAL_LOSS=[]

    for epoch in range(epochs):
        running_loss = 0.0
        model.train()   # get_validation_score leaves the model in eval mode
        for inputs, labels in zip(T, L):
            inputs=torch.Tensor(inputs).to(device)
            labels=torch.Tensor(labels).type(torch.FloatTensor).to(device)

            pred=model(inputs)

            loss=criterion(pred[:,0],labels)
            optimizer_model.zero_grad()
            loss.backward()
            optimizer_model.step()
            # Detach before accumulating: summing the live loss tensors would
            # keep every batch's autograd graph alive for the whole epoch.
            running_loss = running_loss + loss.detach()

        Val_ROC,val_loss=get_validation_score(model,Val_T,Val_L)
        VAL_LOSS.append(val_loss.detach().cpu().numpy())
        LOSS.append((running_loss/len(T)).cpu().numpy())

        print(' Epoch: {:.1f} Training Loss {:5f} Validation Loss {:.4f} Validation AUC {:.5f}'.format(epoch,LOSS[-1],VAL_LOSS[-1],Val_ROC))

        # Checkpoint only on a strict improvement, and remember the new best
        # score so later, worse epochs do not overwrite the file.
        if Val_ROC > best:
            best = Val_ROC
            torch.save(model, checkpoint_path)

    if not LOSS:
        # No epoch ran, so no checkpoint was written by this call; return the
        # model as given instead of loading a possibly stale file.
        return model.to(device),LOSS,VAL_LOSS

    # The checkpoint is a whole pickled nn.Module written above (trusted).
    # PyTorch releases where torch.load defaults to weights_only=True refuse
    # such files unless weights_only=False is passed; older releases do not
    # know the keyword and raise TypeError.
    try:
        best_model = torch.load(checkpoint_path, weights_only=False)
    except TypeError:
        best_model = torch.load(checkpoint_path)

    return best_model.to(device),LOSS,VAL_LOSS

In [None]:
# Get training and validation data.
# NOTE(review): get_data is a project-local module not shown in this notebook;
# the return-value descriptions below follow the original comments — confirm.

from get_data import get_training_data,get_validation_data

T,L=get_training_data(batch_size=32)  # returns lists of training data and label batches
Val_T,Val_L=get_validation_data()     # numpy arrays of validation data and labels

print(T[0].shape)                     # shape of the first batch: (batch_size,time_steps,n_features)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Hyperparameters shared by the SSM and RNN experiments below.
n_features=T[0].shape[2]     # feature dimension per time step (last axis of a batch; 76 per the original note — TODO confirm)
recurrent_units=128          # hidden size: SSM state/output dimension and number of RNN hidden units
lr=0.001                     # learning rate passed to trainer()

In [None]:
# Create SSMClassifier model object (its constructor already moves it to `device`)

model=SSMClassifier(n_features,recurrent_units,device)
print(model)

In [None]:
# Train SSM for 50 epochs; trainer() checkpoints to './ssm' and returns the
# reloaded model together with the per-epoch training and validation losses.

model=model.to(device)
model, training_loss, validation_loss=trainer(model,(T,L),(Val_T,Val_L),device,lr,stored_name='ssm',epochs=50)

In [None]:
# Plot the SSM run's training and validation loss against the epoch number
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(5, 5))
lw = 2
plt.tight_layout()

# Both curves share the same x axis: epochs 1..N
epoch_axis = np.linspace(1, len(training_loss), num=len(training_loss))
for series, colour, legend_label in (
    (training_loss, 'rebeccapurple', 'Training Loss'),
    (validation_loss, 'r', 'Validation_loss'),
):
    ax.plot(epoch_axis, series, color=colour, lw=lw, linestyle='-', label=legend_label)

ax.set_xlabel('# Epochs', fontsize=14)
ax.set_ylabel('# Loss', fontsize=14)
ax.legend(loc="best", fontsize=12)

for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=13)

plt.grid(color='gray', linestyle='-', linewidth=0.1)

plt.show()

In [None]:
# Create RNNClassifier model object with the same sizes as the SSM model.
# This rebinds `model`, replacing the trained SSM model from the cells above.

model=RNNClassifier(n_features,recurrent_units,device)
print(model)


In [None]:
# Train RNNClassifier with the same settings as the SSM run; the checkpoint
# goes to './rnn'. RNNClassifier does not move itself to `device`, so the
# explicit .to(device) is required here. training_loss and validation_loss
# are overwritten with the RNN's curves.

model=model.to(device)
model, training_loss, validation_loss=trainer(model,(T,L),(Val_T,Val_L),device,lr,stored_name='rnn',epochs=50)

In [None]:
# Plot the RNN run's training and validation loss against the epoch number

fig, ax = plt.subplots(figsize=(5, 5))
lw = 2
plt.tight_layout()

# Both curves share the same x axis: epochs 1..N
epoch_axis = np.linspace(1, len(training_loss), num=len(training_loss))
for series, colour, legend_label in (
    (training_loss, 'rebeccapurple', 'Training Loss'),
    (validation_loss, 'r', 'Validation_loss'),
):
    ax.plot(epoch_axis, series, color=colour, lw=lw, linestyle='-', label=legend_label)

ax.set_xlabel('# Epochs', fontsize=14)
ax.set_ylabel('# Loss', fontsize=14)
ax.legend(loc="best", fontsize=12)

for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=13)

plt.grid(color='gray', linestyle='-', linewidth=0.1)

# Uncomment to also write the figure to disk:
#plt.savefig('./loss_curves.pdf',dpi=100,bbox_inches='tight')
plt.show()