In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from torch_optimizer import Lookahead


from data_loader import load_speech_commands_dataset, load_bg_noise_dataset
from utils import set_memory_GB,print_model_size, log_to_file
from augmentations import add_time_shift_and_align, add_silence
from train_utils import trainig_loop





  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):
  def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
  def backward(ctx, grad_output):
  def forward(ctx, zxbcdt, conv1d_weight, conv1d_bias, dt_bias, A, D, chunk_size, initial_states=None, seq_idx=None, dt_limit=(0.0, float("inf")), return_final_states=False, activation="silu",
  def backward(ctx, dout, *args):


In [2]:
configs = {'d_state': 51, 'd_conv': 10, 'expand': 2, 'batch_size': 26, 'dropout_rate': 0.134439213335519, 'num_mamba_layers': 2, 'n_mfcc': 23, 'n_fft': 475, 'hop_length': 119, 'n_mels': 61, 'noise_level': 0.2582577623788829, 'lr': 0.0011942156978344588, 'weight_decay': 2.5617519345807027e-05}


dataset = {
    'fixed_length': 16000,
    'n_mfcc': configs['n_mfcc'],  # Use from configs
    'n_fft': configs['n_fft'],    # Use from configs
    'hop_length': configs['hop_length'],  # Use from configs
    'n_mels': configs['n_mels'],  # Use from configs
    'noise_level': configs['noise_level']  # Use from configs
}


model_configs = {

    'input_dim': configs['n_mfcc'] * 3,  # Use from configs
    'd_model': (dataset['fixed_length'] // dataset['hop_length']) + 1 + 1,  # Use from configs
    'd_state': configs['d_state'],  # Use from configs
    'd_conv': configs['d_conv'],    # Use from configs
    'expand': configs['expand'],    # Use from configs
    'num_mamba_layers': configs['num_mamba_layers'],  # Use from configs
    'dropout_rate': configs['dropout_rate'],  # Use from configs
    'label_names': ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']
}

In [3]:
torch.cuda.is_available()

True

In [4]:
set_memory_GB(1)

Memory fraction set to 0.022458079576498518
Memory fraction in GB: 1.0


In [5]:
train_ds, val_ds, test_ds, silence_ds , info = load_speech_commands_dataset()
# bg_noise_ds = load_bg_noise_dataset()
bg_noise_ds = None

2024-09-26 15:31:47.618479: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-09-26 15:31:47.625593: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-09-26 15:31:47.639590: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-26 15:31:47.662105: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-26 15:31:47.668486: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attemptin

In [6]:
# maintain seed for repructablity
np.seed = 42
# tf.random.set_seed(42)
torch.manual_seed(0)

<torch._C.Generator at 0x7f01a6db63d0>

In [7]:
label_names = ['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']
print(label_names)

['down', 'go', 'left', 'no', 'off', 'on', 'right', 'stop', 'up', 'yes']


In [8]:
augmentations = [
    lambda x: add_time_shift_and_align(x),
]

In [9]:
import torch
import numpy as np
import random
from torch.utils.data import Dataset
from librosa.feature import mfcc, delta

class TFDatasetAdapter(Dataset):
    def __init__(self, tf_dataset, bg_noise_dataset, fixed_length, n_mfcc, n_fft, hop_length, n_mels, augmentation=False, derivative=True, noise_level=0.3, MFCC_transform=True, quantize_8bit=False):
        self.tf_dataset = tf_dataset
        self.data = list(tf_dataset)
        self.bg_noise_data = list(bg_noise_dataset) if bg_noise_dataset is not None else None
        self.fixed_length = fixed_length
        self.n_mfcc = n_mfcc
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.augmentation = augmentation
        self.derivative = derivative
        self.noise_level = noise_level
        self.MFCC_transform = MFCC_transform
        self.quantize_8bit = quantize_8bit  # New parameter for 8-bit quantization

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        audio, label = self.data[idx]
        audio = audio.numpy()

        # Normalize the audio tensor
        audio = audio / np.max(np.abs(audio))

        # Convert to float
        audio = audio.astype(np.float32)

        # Ensure the audio tensor has the correct shape (1D array)
        if audio.ndim > 1:
            audio = np.squeeze(audio)

        # Add noise from bg_noise data
        if self.bg_noise_data:
            bg_noise_audio = random.choice(self.bg_noise_data)

            # Trim or pad bg_noise to match the audio length
            if len(bg_noise_audio) < len(audio):
                bg_noise_audio = np.pad(bg_noise_audio, (0, len(audio) - len(bg_noise_audio)), mode='constant')
            else:
                # Take a random slice of bg_noise_audio with the same length as the original audio
                start_idx = random.randint(0, len(bg_noise_audio) - len(audio))
                bg_noise_audio = bg_noise_audio[start_idx:start_idx + len(audio)]

            # Add bg_noise as noise to the original audio
            audio = audio + self.noise_level * bg_noise_audio

        # Pad or trim the audio to the fixed length
        if len(audio) < self.fixed_length:
            audio = np.pad(audio, (0, self.fixed_length - len(audio)), mode='constant')
        else:
            audio = audio[:self.fixed_length]

        output = audio

        # Apply augmentations if any
        if self.augmentation:
            for aug in self.augmentation:
                audio = aug(audio)

        # Apply MFCC transformation if enabled
        if self.MFCC_transform:
            audio = audio.astype(np.float32)
            MFCC = mfcc(y=audio, sr=16000, n_mfcc=self.n_mfcc, n_fft=self.n_fft, hop_length=self.hop_length, n_mels=self.n_mels)
            
            if self.derivative:
                # Create MFCC first and second-order deltas
                MFCC_delta = delta(MFCC)
                MFCC_delta2 = delta(MFCC, order=2)

                # Stack MFCC with its deltas
                MFCC = np.vstack([MFCC, MFCC_delta, MFCC_delta2])

            # Remove extra dimension if it exists
            if output.ndim == 3:
                MFCC = MFCC.squeeze(-1)

            output = MFCC

        # Apply 8-bit quantization if the option is enabled
        if self.quantize_8bit:
            output = (output * 127).astype(np.int8)  # Scale float32 to int8 range (-128 to 127)

        return torch.tensor(output, dtype=torch.float32 if not self.quantize_8bit else torch.int8), torch.tensor(label.numpy(), dtype=torch.long)

In [10]:
# Convert the TFDS dataset to a PyTorch Dataset
fixed_length = 16000
n_mfcc = 13
n_fft = 640
hop_length = 80
n_mels = 100
#take just 10 of the dataset
train_ds = train_ds.take(1000)
val_ds = val_ds.take(100)

# Initialize datasets with configurations
pytorch_train_dataset = TFDatasetAdapter(train_ds, bg_noise_ds, **dataset, augmentation=[lambda x: add_time_shift_and_align(x)])
pytorch_val_dataset = TFDatasetAdapter(val_ds, None, **dataset, augmentation=None)

2024-09-26 15:31:51.523732: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
2024-09-26 15:31:51.806392: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


In [11]:
# #play sound from dataset
# import IPython.display as ipd

# for i in range(10):
#     x, y = pytorch_train_dataset[i]
#     print(label_names[y])
#     ipd.display(ipd.Audio(x.numpy(), rate=16000))
#     # print(x.shape)

In [12]:
# Create a DataLoader to feed the data into the model
batch_size = 32
train_loader = DataLoader(pytorch_train_dataset, batch_size=batch_size, shuffle=True,num_workers=4,prefetch_factor=2)
val_loader = DataLoader(pytorch_val_dataset, batch_size=batch_size, shuffle=False,num_workers=4,prefetch_factor=2)

# RNN based SSM

In [13]:
import torch
import torch.nn as nn

# Define the RNN-to-SSM transformation in PyTorch

class SSM(nn.Module):
    def __init__(self, input_size, state_size, output_size):
        super(SSM, self).__init__()
        
        # SSM matrices (A, B, C, D) as parameters
        self.A = nn.Parameter(torch.randn(state_size, state_size))
        self.B = nn.Parameter(torch.randn(state_size, input_size))
        self.C = nn.Parameter(torch.randn(output_size, state_size))
        self.D = nn.Parameter(torch.zeros(output_size, input_size))
        
        # Initialize the state (equivalent to hidden state in RNN)
        self.state = torch.zeros(state_size, 1)

    def forward(self, x):
        device = x.device  # Ensure state is on the same device as input
        self.state = self.state.to(device)
        self.A = self.A.to(device)
        self.B = self.B.to(device)
        self.C = self.C.to(device)
        self.D = self.D.to(device)
        
        # x is the input sequence (batch_size, sequence_length, input_size)
        batch_size, seq_len, input_size = x.size()
        outputs = []
        
        # Iterate over the sequence
        for t in range(seq_len):
            u_t = x[:, t, :].unsqueeze(-1)  # Current input (batch_size, input_size, 1)
            
            # State update equation: x_{t+1} = A * x_t + B * u_t
            self.state = self.A @ self.state + self.B @ u_t
            
            # Output equation: y_t = C * x_t + D * u_t
            y_t = self.C @ self.state + self.D @ u_t
            outputs.append(y_t.squeeze(-1))
        
        return torch.stack(outputs, dim=1)  # (batch_size, sequence_length, output_size)


In [14]:

# Example usage
input_size = 1   # Input dimension (e.g., 1D time series)
state_size = 2   # State dimension (hidden state size)
output_size = 1  # Output dimension

# Initialize the SSM model
ssm_model = SSM(input_size=input_size, state_size=state_size, output_size=output_size)

# Create a dummy input sequence (batch_size=1, sequence_length=10, input_size=1)
input_sequence = torch.sin(torch.linspace(0, 2 * torch.pi, 10)).unsqueeze(0).unsqueeze(-1)

# Forward pass through the SSM
output_sequence = ssm_model(input_sequence)
print(output_sequence)


tensor([[[  0.0000],
         [ -1.0346],
         [ -1.0673],
         [  0.7362],
         [  4.8928],
         [ 12.1339],
         [ 24.3951],
         [ 46.5356],
         [ 89.2888],
         [174.6364]]], grad_fn=<StackBackward0>)


In [15]:
import torch
import torch.nn as nn

class Mamba(nn.Module):
    def __init__(self, d_model, d_state, d_conv, expand):
        super(Mamba, self).__init__()
        self.expand = expand
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.d_inner = int(self.expand * self.d_model)

        # Define the layers for the MambaBlock
        self.x_proj = nn.Linear(self.d_model, self.d_inner)
        self.x_res_proj = nn.Linear(self.d_model, self.d_inner)

        # Set padding to "same" to maintain input size
        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,  # Conv1d in_channels will match d_model after projection
            out_channels=self.d_inner,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding='same'  # Ensure the sequence length stays the same
        )

        # Placeholder SSM 
        self.ssm_model = SSM(input_size=self.d_inner, state_size=d_state, output_size=self.d_inner)
        # Activation function SiLU (Swish)
        self.activation = nn.SiLU()

        # Output projection layer
        self.out_proj = nn.Linear(self.d_inner, self.d_model)

    def mambaBlock(self, x):
        # x is of shape (batch_size, seq_len, input_dim)
        # Print the shape of x before projection
        # print(f"Shape of x before projection: {x.shape}")
        # Apply input projection layers
        x = x.transpose(1, 2)  # Transpose to (batch_size, d_inner, seq_len)
        x_res = self.x_res_proj(x)  # Residual connection
        x = self.x_proj(x)          # Main path projection

        x = x.permute(0, 2, 1)  # Transpose to (batch_size, seq_len, d_inner)

        
        # Apply Conv1d
        x = self.conv1d(x)

        # Transpose back to (batch_size, seq_len, d_inner)
        x = x.transpose(1, 2)

        # Apply activation function
        x = self.activation(x)
        x_res = self.activation(x_res)  # Optional: Activation on residual path

        # Apply SSM layer
        x = self.ssm_model(x)

        # Ensure x_res and x have the same size along the sequence dimension
        if x.size(1) != x_res.size(1):
            min_len = min(x.size(1), x_res.size(1))
            x = x[:, :min_len, :]
            x_res = x_res[:, :min_len, :]

        # Residual connection - element-wise addition
        x = x + x_res  # Standard residual connection

        return x

    def forward(self, x):
        """
        Forward pass for the Mamba model.
        """
        # Apply the Mamba block
        x = self.mambaBlock(x)

        # Apply final output projection
        x = self.out_proj(x)

        return x


In [16]:

class KeywordSpottingModel_with_cls(nn.Module):
    def __init__(self, input_dim, d_model, d_state, d_conv, expand, label_names, num_mamba_layers=1, dropout_rate=0.2):
        super(KeywordSpottingModel_with_cls, self).__init__()
        
        # Initial projection layer
        self.proj = nn.Linear(input_dim, d_model)  
        
        # CLS token: learnable parameter with shape [1, 1, d_model]
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        
        
        # Stack multiple Mamba layers with RMSNorm layer
        self.mamba_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()

        for _ in range(num_mamba_layers):
            self.mamba_layers.append(Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand))
            self.layer_norms.append(nn.modules.normalization.RMSNorm(d_model))

        # Output layer
        self.fc = nn.Linear(d_model, len(label_names))  
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        # # Quantize the input
        # x = self.quant(x)

        # Reshape to [batch_size, num_frames, num_mfcc]
        x = x.permute(0, 2, 1)
        
        # Dequantize before projection to ensure dtype match with weights
        # x = self.dequant(x)
        
        # Project input to d_model dimension
        x = self.proj(x)  
        
        # Re-quantize after projection (optional, based on your quantization strategy)
        # x = self.quant(x)
        
        # Create a CLS token and expand it across the batch dimension
        batch_size = x.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)  # Shape: [batch_size, 1, d_model]
        
        # Append the CLS token to the input sequence
        x = torch.cat((x, cls_tokens), dim=1)  # Shape: [batch_size, num_frames + 1, d_model]
        x = x.permute(0, 2, 1)  # Transpose to [batch_size, d_model, num_frames + 1] for Mamba
        
        # Pass through Mamba layers and layer normalization
        for mamba_layer, layer_norm in zip(self.mamba_layers, self.layer_norms):
            x = mamba_layer(x)
            x = layer_norm(x)  # Apply RMSNorm after Mamba layer

        x = self.dropout(x)  # Apply dropout after Mamba layers
        
        # Extract the CLS token output (last token)
        cls_output = x[:, :, -1]  # Shape: [batch_size, d_model]
        
        # # Dequantize before the final fully connected layer
        # cls_output = self.dequant(cls_output)
        
        # Pass through the output layer
        x = self.fc(cls_output)
        
        return x

In [17]:
# Initialize the Mamba model
mamba_model = KeywordSpottingModel_with_cls(input_dim=model_configs['input_dim'], d_model=model_configs['d_model'], d_state=model_configs['d_state'], d_conv=model_configs['d_conv'], expand=model_configs['expand'], label_names=model_configs['label_names'], num_mamba_layers=2).to('cuda')
# mamba_model = Mamba(d_model=model_configs['d_model'], d_state=model_configs['d_state'], d_conv=model_configs['d_conv'], expand=model_configs['expand'])
# Create a dummy input sequence (batch_size=1, sequence_length=10, input_size=1)
# input_sequence = torch.randn(configs['batch_size'], model_configs['input_dim'], model_configs['d_model']).to('cuda')

# # Forward pass through the Mamba model
# output_sequence = mamba_model(input_sequence)
# print(output_sequence)

In [19]:
states = torch.load('best_model_95.pth')

new_state_dict = mamba_model.state_dict()

for state in states:
    print(state)
    

cls_token
proj.weight
proj.bias
mamba_layers.0.A_log
mamba_layers.0.D
mamba_layers.0.in_proj.weight
mamba_layers.0.conv1d.weight
mamba_layers.0.conv1d.bias
mamba_layers.0.x_proj.weight
mamba_layers.0.dt_proj.weight
mamba_layers.0.dt_proj.bias
mamba_layers.0.out_proj.weight
mamba_layers.1.A_log
mamba_layers.1.D
mamba_layers.1.in_proj.weight
mamba_layers.1.conv1d.weight
mamba_layers.1.conv1d.bias
mamba_layers.1.x_proj.weight
mamba_layers.1.dt_proj.weight
mamba_layers.1.dt_proj.bias
mamba_layers.1.out_proj.weight
layer_norms.0.weight
layer_norms.1.weight
fc.weight
fc.bias


  states = torch.load('best_model_95.pth')


In [20]:
for state in new_state_dict:
    print(state)

cls_token
proj.weight
proj.bias
mamba_layers.0.x_proj.weight
mamba_layers.0.x_proj.bias
mamba_layers.0.x_res_proj.weight
mamba_layers.0.x_res_proj.bias
mamba_layers.0.conv1d.weight
mamba_layers.0.conv1d.bias
mamba_layers.0.ssm_model.A
mamba_layers.0.ssm_model.B
mamba_layers.0.ssm_model.C
mamba_layers.0.ssm_model.D
mamba_layers.0.out_proj.weight
mamba_layers.0.out_proj.bias
mamba_layers.1.x_proj.weight
mamba_layers.1.x_proj.bias
mamba_layers.1.x_res_proj.weight
mamba_layers.1.x_res_proj.bias
mamba_layers.1.conv1d.weight
mamba_layers.1.conv1d.bias
mamba_layers.1.ssm_model.A
mamba_layers.1.ssm_model.B
mamba_layers.1.ssm_model.C
mamba_layers.1.ssm_model.D
mamba_layers.1.out_proj.weight
mamba_layers.1.out_proj.bias
layer_norms.0.weight
layer_norms.1.weight
fc.weight
fc.bias


In [21]:
def map_state_dict(state_dict1):
    mapped_state_dict = {}

    # Directly copy common keys
    common_keys = [
        "cls_token",
        "proj.weight",
        "proj.bias",
        "mamba_layers.0.conv1d.weight",
        "mamba_layers.0.conv1d.bias",
        "mamba_layers.0.out_proj.weight",
        "mamba_layers.1.conv1d.weight",
        "mamba_layers.1.conv1d.bias",
        "mamba_layers.1.out_proj.weight",
        "layer_norms.0.weight",
        "layer_norms.1.weight",
        "fc.weight",
        "fc.bias"
    ]

    for key in common_keys:
        if key in state_dict1:
            mapped_state_dict[key] = state_dict1[key]

    # Mamba layer 0 mappings
    if "mamba_layers.0.A_log" in state_dict1:
        mapped_state_dict["mamba_layers.0.ssm_model.A"] = state_dict1["mamba_layers.0.A_log"]
    if "mamba_layers.0.D" in state_dict1:
        mapped_state_dict["mamba_layers.0.ssm_model.D"] = state_dict1["mamba_layers.0.D"]
    if "mamba_layers.0.in_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.0.x_res_proj.weight"] = state_dict1["mamba_layers.0.in_proj.weight"]
    if "mamba_layers.0.x_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.0.x_proj.weight"] = state_dict1["mamba_layers.0.x_proj.weight"]
    if "mamba_layers.0.dt_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.0.ssm_model.B"] = state_dict1["mamba_layers.0.dt_proj.weight"]
    if "mamba_layers.0.dt_proj.bias" in state_dict1:
        mapped_state_dict["mamba_layers.0.ssm_model.C"] = state_dict1["mamba_layers.0.dt_proj.bias"]

    # Mamba layer 1 mappings
    if "mamba_layers.1.A_log" in state_dict1:
        mapped_state_dict["mamba_layers.1.ssm_model.A"] = state_dict1["mamba_layers.1.A_log"]
    if "mamba_layers.1.D" in state_dict1:
        mapped_state_dict["mamba_layers.1.ssm_model.D"] = state_dict1["mamba_layers.1.D"]
    if "mamba_layers.1.in_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.1.x_res_proj.weight"] = state_dict1["mamba_layers.1.in_proj.weight"]
    if "mamba_layers.1.x_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.1.x_proj.weight"] = state_dict1["mamba_layers.1.x_proj.weight"]
    if "mamba_layers.1.dt_proj.weight" in state_dict1:
        mapped_state_dict["mamba_layers.1.ssm_model.B"] = state_dict1["mamba_layers.1.dt_proj.weight"]
    if "mamba_layers.1.dt_proj.bias" in state_dict1:
        mapped_state_dict["mamba_layers.1.ssm_model.C"] = state_dict1["mamba_layers.1.dt_proj.bias"]

    return mapped_state_dict

In [22]:
new_state_dict = map_state_dict(states)

In [23]:
loaded_model = mamba_model.load_state_dict(new_state_dict, strict=False)

RuntimeError: Error(s) in loading state_dict for KeywordSpottingModel_with_cls:
	size mismatch for mamba_layers.0.x_proj.weight: copying a param with shape torch.Size([111, 272]) from checkpoint, the shape in current model is torch.Size([272, 136]).
	size mismatch for mamba_layers.0.x_res_proj.weight: copying a param with shape torch.Size([544, 136]) from checkpoint, the shape in current model is torch.Size([272, 136]).
	size mismatch for mamba_layers.0.ssm_model.A: copying a param with shape torch.Size([272, 51]) from checkpoint, the shape in current model is torch.Size([51, 51]).
	size mismatch for mamba_layers.0.ssm_model.B: copying a param with shape torch.Size([272, 9]) from checkpoint, the shape in current model is torch.Size([51, 272]).
	size mismatch for mamba_layers.0.ssm_model.C: copying a param with shape torch.Size([272]) from checkpoint, the shape in current model is torch.Size([272, 51]).
	size mismatch for mamba_layers.0.ssm_model.D: copying a param with shape torch.Size([272]) from checkpoint, the shape in current model is torch.Size([272, 272]).
	size mismatch for mamba_layers.1.x_proj.weight: copying a param with shape torch.Size([111, 272]) from checkpoint, the shape in current model is torch.Size([272, 136]).
	size mismatch for mamba_layers.1.x_res_proj.weight: copying a param with shape torch.Size([544, 136]) from checkpoint, the shape in current model is torch.Size([272, 136]).
	size mismatch for mamba_layers.1.ssm_model.A: copying a param with shape torch.Size([272, 51]) from checkpoint, the shape in current model is torch.Size([51, 51]).
	size mismatch for mamba_layers.1.ssm_model.B: copying a param with shape torch.Size([272, 9]) from checkpoint, the shape in current model is torch.Size([51, 272]).
	size mismatch for mamba_layers.1.ssm_model.C: copying a param with shape torch.Size([272]) from checkpoint, the shape in current model is torch.Size([272, 51]).
	size mismatch for mamba_layers.1.ssm_model.D: copying a param with shape torch.Size([272]) from checkpoint, the shape in current model is torch.Size([272, 272]).