### Overview
 This file is a variation of Mamba-tiny by https://github.com/PeaBrane/mamba-tiny/tree/master

 It contains a compact version of Mamba which is fully functional, but does not implement the hardware-aware GPU kernels or the parallel scan of the official implementation.

In addition to the changes done in 'Mamba model', this model has the additional functionality of running sequentially, so that during inference the new datapoints are entered one by one. This means that the hidden states are handled such that they can be saved from step to step instead of automatically starting from zeros.\
To work in this mode, the model needs to be initialized with 'generate = True' in the model args configuration.\
If we wish to reset the model to start a new sequence, we can call the 'reset_h' method of the model args (e.g. `model.args.reset_h()`).

After constructing this model variation we added the 'generate_k_last_predictions' function and made sure that we get the same results for running it with k=1 as for the non-generative model operation.

In [None]:
import io
import os
import sys
import copy
from datetime import datetime
import pickle
import warnings
warnings.filterwarnings('ignore')
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from google.colab import drive
drive.mount('/content/drive/')
# helper files
sys.path.append('/content/drive/MyDrive/Final Project UAV/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [None]:
cd /content/drive/MyDrive/Final Project UAV/

/content/drive/MyDrive/Final Project UAV


In [None]:
"""Simple, minimal implementation of Mamba in one file of PyTorch.

Suggest reading the following before/while reading the code:
    [1] Mamba: Linear-Time Sequence Modeling with Selective State Spaces (Albert Gu and Tri Dao)
        https://arxiv.org/abs/2312.00752
    [2] The Annotated S4 (Sasha Rush and Sidd Karamcheti)
        https://srush.github.io/annotated-s4

Glossary:
    b: batch size                       (`B` in Mamba paper [1] Algorithm 2)
    l: sequence length                  (`L` in [1] Algorithm 2)
    d or d_model: hidden dim
    n or d_state: latent state dim      (`N` in [1] Algorithm 2)
    expand: expansion factor            (`E` in [1] Section 3.4)
    d_in or d_inner: d * expand         (`D` in [1] Algorithm 2)
    A, B, C, D: state space parameters  (See any state space representation formula)
                                        (B, C are input-dependent (aka selective, a key innovation in Mamba); A, D are not)
    Δ or delta: input-dependent step size
    dt_rank: rank of Δ                  (See [1] Section 3.6 "Parameterization of ∆")

"""
!pip install einops
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum

Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.8.0


In [None]:
@dataclass
class ModelArgs:
    """Hyper-parameters and run-mode flags shared by every layer of the model."""

    # Width of the model input/output features (`d` in the glossary).
    d_model: int
    # Number of stacked residual Mamba blocks.
    n_layer: int
    # Latent state size of the SSM (`n` in the glossary).
    d_state: int = 16
    # Expansion factor; d_inner is derived from it in __post_init__.
    expand: int = 2
    # Kernel size of the depthwise causal convolution.
    d_conv: int = 4
    # Not referenced by the model code visible in this file.
    pad_vocab_size_multiple: int = 8
    # Whether the depthwise convolution has a bias term.
    conv_bias: bool = True
    # Whether the input/output linear projections have a bias term.
    bias: bool = False
    # True -> blocks carry their hidden state from one forward call to the next.
    generate: bool = False
    # Counter compared by each block against its own copy; a larger value here
    # tells the block to drop its stored hidden state.
    h_revision: int = 0

    def __post_init__(self):
        # Derived width of the expanded inner representation (`d_in`).
        self.d_inner = int(self.d_model * self.expand)

    def reset_h(self):
        """Request a hidden-state reset by advancing the revision counter."""
        self.h_revision = self.h_revision + 1


In [None]:
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

    def forward(self, x, delta, h = None):
        """
        Args:
            x (long tensor): shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta (float tensor): shape (b, l, 1) --> this is the added time interval vector
        Returns:
            output: shape (b, l, d) (prediction of future values)

        Official Implementation:
            class MambaLMHeadModel, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/models/mixer_seq_simple.py#L173

        """
        for layer in self.layers:
            x = layer(x, delta)

        output = self.norm_f(x)

        return output

In [None]:
class ResidualBlock(nn.Module):
    """Pre-norm residual wrapper around a single MambaBlock."""

    def __init__(self, args: ModelArgs):
        """Create the wrapped Mamba block and the RMSNorm applied before it."""
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x, delta):
        """Compute mixer(norm(x), delta) + x.

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l) --> the added time interval vector, passed straight to the mixer
        Returns:
            output: shape (b, l, d)

        Official Implementation:
            Block.forward(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L297

            Note: the official repo chains its blocks as
                [Add -> Norm -> Mamba] -> [Add -> Norm -> Mamba] -> ...
            (the first Add being a no-op) so that Add->Norm can be fused for speed.
            Here the simpler, numerically equivalent ordering is used:
                [Norm -> Mamba -> Add] -> [Norm -> Mamba -> Add] -> ...

        """
        normed = self.norm(x)
        mixed = self.mixer(normed, delta)

        return mixed + x

In [None]:
class MambaBlock(nn.Module):
    """A single Mamba block, as described in Figure 3 in Section 3.4 in the Mamba paper [1].

    Besides the usual full-sequence operation, the block can run as a stream:
    the SSM hidden state `self.h` and the causal-convolution history
    `self.conv_state` left by one forward call are picked up by the next call
    when `args.generate` is True. Bumping `args.h_revision` (ModelArgs.reset_h)
    discards both on the next call.
    """

    def __init__(self, args: ModelArgs):
        """Create the projections, the depthwise causal conv and the SSM parameters."""
        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        # Depthwise conv (groups == channels). Padding d_conv - 1 on both sides plus
        # truncation to the first l outputs in forward() makes it causal.
        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )

        # x_proj takes in `x` and outputs the input-specific B, C
        self.x_proj = nn.Linear(args.d_inner, args.d_state * 2, bias=False)

        # A[d, :] = 1..d_state for every channel; torch.log promotes the integer
        # tensor to the default float dtype.
        A = torch.arange(1, args.d_state + 1).repeat(args.d_inner, 1)  # (d_in, n)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

        # Streaming state (plain attributes, not parameters/buffers).
        self.h = None           # (b, d_in, n) SSM state after the last processed step
        self.conv_state = None  # (b, d_in, d_conv - 1) last inputs seen by the conv
        self.h_revision = args.h_revision

    def forward(self, x, delta):
        """Mamba block forward. This looks the same as Figure 3 in Section 3.4 in the Mamba paper [1].

        Args:
            x: shape (b, l, d)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l) or (b, l, 1) --> time interval of every step
        Returns:
            output: shape (b, l, d)
        Raises:
            ValueError: in generate mode, if the stored state was produced with a
                different batch size than `x`.

        Official Implementation:
            class Mamba, https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        # Apply a pending reset before the conv reads its cached history.
        self._sync_state_reset()

        d_inner = self.args.d_inner
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[d_inner, d_inner], dim=-1)

        x = self._causal_conv(x.transpose(1, 2)).transpose(1, 2)  # (b, l, d_in)
        x = F.silu(x)

        y = self.ssm(x, delta)
        y = y * F.silu(res)

        return self.out_proj(y)

    def _sync_state_reset(self):
        """Drop the stored streaming state if args.h_revision moved past our copy."""
        if self.args.h_revision > self.h_revision:
            self.h = None
            self.conv_state = None
            self.h_revision = self.args.h_revision

    @staticmethod
    def _check_state_batch(state, batch, name):
        """Refuse to reuse a cached state recorded for a different batch size."""
        if state.shape[0] != batch:
            raise ValueError(
                f"stored {name} has batch size {state.shape[0]} but the input has "
                f"batch size {batch}; call args.reset_h() before starting a new sequence"
            )

    def _causal_conv(self, x):
        """Depthwise causal convolution over x of shape (b, d_in, l), with streaming history.

        Without a usable history this is exactly `self.conv1d(x)[:, :, :l]`. In
        generate mode with a stored history, the d_conv - 1 previous inputs are
        prepended instead of zeros, so feeding a sequence in pieces yields the
        same values as feeding it at once.
        """
        l = x.shape[-1]
        history = self.args.d_conv - 1
        carry = self.conv_state if self.args.generate else None

        if carry is None:
            out = self.conv1d(x)[:, :, :l]
            # Left-pad short inputs with the zeros the conv implicitly saw.
            context = x if l >= history else F.pad(x, (history - l, 0))
        else:
            self._check_state_batch(carry, x.shape[0], "conv_state")
            context = torch.cat([carry, x], dim=-1)  # (b, d_in, history + l)
            out = F.conv1d(context, self.conv1d.weight, self.conv1d.bias,
                           groups=self.conv1d.groups)

        # Keep only the last `history` columns; clone so the (small) cache does not
        # pin the full activation tensor, detach so it carries no autograd graph.
        self.conv_state = context[:, :, context.shape[-1] - history:].detach().clone()
        return out

    def ssm(self, x, delta):
        """Runs the SSM. See:
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, u) in The Annotated S4 [2]

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l) or (b, l, 1)
        Returns:
            output: shape (b, l, d_in)

        Official Implementation:
            mamba_inner_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L311

        """
        (b, l, d_in) = x.shape
        n = self.A_log.shape[1]

        # Compute A B C D, the state space parameters.
        #     A, D are input independent (see Mamba paper [1] Section 3.5.2 "Interpretation of A" for why A isn't selective)
        #     B, C are input-dependent (this is a key difference between Mamba and the linear time invariant S4,
        #                                  and is why Mamba is called **selective** state spaces)
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        # The same externally supplied step size is used for every channel.
        delta = delta.reshape(b, l, 1).expand(b, l, d_in)  # shape (b, l, d_in)

        x_bl = self.x_proj(x)  # (b, l, 2*n)
        (B, C) = x_bl.split(split_size=[n, n], dim=-1)  # B, C: (b, l, n)

        # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        return self.selective_scan(x, delta, A, B, C, D)

    def selective_scan(self, x, delta, A, B, C, D):
        """Does selective scan algorithm. See:
            - Section 2 State Space Models in the Mamba paper [1]
            - Algorithm 2 in Section 3.2 in the Mamba paper [1]
            - run_SSM(A, B, C, x) in The Annotated S4 [2]

        This is the classic discrete state space formula:
            h(t + 1) = Ah(t) + Bx(t)
            y(t)     = Ch(t) + Dx(t)
        except B and C (and the step size delta, which is used for discretization) are dependent on the input x(t).

        Args:
            x: shape (b, l, d_in)    (See Glossary at top for definitions of b, l, d_in, n...)
            delta: shape (b, l, d_in)
            A: shape (d_in, n)
            B: shape (b, l, n)
            C: shape (b, l, n)
            D: shape (d_in,)

        Returns:
            output: shape (b, l, d_in)

        Raises:
            ValueError: in generate mode, if the stored hidden state was produced
                with a different batch size than `x`.

        Side effects:
            Stores the (detached) hidden state after the last step in `self.h`.

        Official Implementation:
            selective_scan_ref(), https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py#L86
            Note: some parts were refactored out of `selective_scan_ref`, so the functionality doesn't match exactly.

        """
        (b, l, d_in) = x.shape
        n = A.shape[1]

        # Discretize continuous parameters (A, B)
        # - A is discretized using zero-order hold (ZOH) discretization (see Section 2 Equation 4 in the Mamba paper [1])
        # - B is discretized using a simplified Euler discretization instead of ZOH. From a discussion with authors:
        #   "A is the more important term and the performance doesn't change much with the simplification on B"
        deltaA = torch.exp(torch.einsum('bld,dn->bldn', delta, A))
        deltaB_x = torch.einsum('bld,bln,bld->bldn', delta, B, x)

        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])
        # Note that the below is sequential, while the official implementation does a much faster parallel scan that
        # is additionally hardware-aware (like FlashAttention).

        # Honour a pending reset (also covers direct calls that bypass forward()).
        self._sync_state_reset()

        # Start from the stored state only in generate mode; otherwise from zeros.
        if self.args.generate and self.h is not None:
            self._check_state_batch(self.h, b, "hidden state")
            h = self.h
        else:
            h = torch.zeros((b, d_in, n), device=deltaA.device, dtype=deltaA.dtype)

        ys = []
        for i in range(l):
            h = deltaA[:, i] * h + deltaB_x[:, i]
            ys.append(torch.einsum('bdn,bn->bd', h, C[:, i]))

        # Detach: the cache must not keep the autograd graph of this call alive
        # (that would break a later backward() and copy.deepcopy of the module).
        self.h = h.detach()

        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        return y + x * D

In [None]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned per-feature gain."""

    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        """
        Args:
            d_model: size of the last axis; length of the gain vector
            eps: constant added to the mean square before the reciprocal square root
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        """Scale x by 1 / sqrt(mean(x^2, last axis) + eps), then by the gain `weight`."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)

        return x * inv_rms * self.weight


Testing the model on our samples

In [None]:
# Load the pre-built train/test samples and their time-interval vectors.
# `samples_config` is not defined in the cells shown here; presumably it comes
# from an earlier cell or helper file -- verify before running.
save_path = './Samples/mamba_samples_testing0_' + samples_config['subfolder'] + '_samples'
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# sample files produced by this project.
with open(save_path , 'rb') as f:
  train_samples, test_samples, train_samples_filenames, test_samples_filenames, train_dt, test_dt = pickle.load(f)

In [None]:
# Pair every sample with its time-interval vector and take the first pair.
# NOTE(review): despite the name, `test_data` is built from the *train* split
# (train_samples / train_dt) -- confirm this is intended.
test_data = list(zip(train_samples, train_dt))
dl_test0 = DataLoader(test_data, batch_size = 1)
sample, dt = next(iter(dl_test0))
sample.shape

torch.Size([1, 251, 8])

In [None]:
# Model configuration for the test run.
d_model = 8   # equals the last dimension of `sample` (see the shape printed above)
n_layer = 1
d_state = 16
expand = 2
d_conv = 4
# conv_bias = True
# bias = False
generate = True  # keep the hidden state between forward calls
# conv_bias and bias are left at their ModelArgs defaults.
args = ModelArgs(d_model, n_layer, d_state, expand, d_conv, generate=generate)

In [None]:
# model2 is an independent copy of model1 (same initial weights), taken before
# any forward pass so that neither model holds a stored hidden state yet.
model1 = Mamba(args)
model2 = copy.deepcopy(model1)

In [None]:
output = model1(sample, dt)

In [None]:
output.shape

In [None]:
print(output[0,-10:, 0])

tensor([-2.7939, -2.8003, -2.7975, -2.7481, -2.7613, -2.7999, -2.8008, -2.7924,
        -2.7449, -2.7335], grad_fn=<SelectBackward0>)


### Generate

In [None]:
def generate_k_last_predictions(model, k, sample, dt):
  """Warm the model up on all but the last k steps, then roll out k steps autoregressively.

  The model is first run on the header `sample[:, :-k]`. Its output at the last
  header step is the first prediction; every following prediction is obtained by
  feeding the previous prediction back as a length-1 sequence together with the
  matching entry of the last k time intervals. This relies on the model keeping
  its state between calls (generate=True in its args).

  Args:
    model: callable as model(x, delta). If it exposes `model.args.reset_h()`, that
      is called first so state left by an earlier sequence cannot leak into this one.
    k: number of trailing steps to roll out; must satisfy 1 <= k < sample.shape[1]
    sample: shape (b, l, d)
    dt: shape (b, l)
  Returns:
    predictions: shape (b, k + 1, d).
      !note that the last prediction here does not have a label and should not be
      used: k + 1 entries are returned with only the first k relevant.
  Raises:
    ValueError: if k is outside 1 <= k < sample.shape[1].
  """
  seq_len = sample.shape[1]
  if not 1 <= k < seq_len:
    # k == 0 would make `[:-k]` an empty header; k >= l leaves nothing to warm up on.
    raise ValueError(f"k must satisfy 1 <= k < sequence length ({seq_len}), got {k}")

  # Start from a clean state when the model offers the reset hook.
  reset = getattr(getattr(model, "args", None), "reset_h", None)
  if callable(reset):
    reset()

  sample_header = sample[:, :-k, :]
  dt_header = dt[:, :-k]
  dt_tail = dt[:, -k:]

  with torch.no_grad():
    output = model(sample_header, dt_header)
    y = output[:, -1, :].unsqueeze(1) #generating from last prediction
    pred = [y]
    for i in range(k):
      single_dt = dt_tail[:, i].unsqueeze(1)
      y = model(y, single_dt)
      pred.append(y)
    predictions = torch.cat(pred, dim=1)
  return predictions

In [None]:
predictions = generate_k_last_predictions(model2, 5, sample, dt)

In [None]:
predictions