In [3]:
### Get the MAMBA form GitHub "https://github.com/state-spaces/mamba" ###
!git clone https://github.com/state-spaces/mamba.git
%cd mamba
!pip install .
!pip install datasets transformers torchprofile
!pip install triton==2.0.0
!pip install thop

fatal: destination path 'mamba' already exists and is not an empty directory.
/content/mamba/mamba
Processing /content/mamba/mamba
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: mamba_ssm
  Building wheel for mamba_ssm (pyproject.toml) ... [?25l[?25hdone
  Created wheel for mamba_ssm: filename=mamba_ssm-2.2.4-cp310-cp310-linux_x86_64.whl size=323653202 sha256=c6edff068928b4ceacc6e820a163908c02294fe01dbcc05e95ed4530a5f81e77
  Stored in directory: /tmp/pip-ephem-wheel-cache-pgyxidbz/wheels/c8/52/57/e1d47a0b5671ea2c7e3d2105232f3861a6933be5d33b6abd22
Successfully built mamba_ssm
Installing collected packages: mamba_ssm
  Attempting uninstall: mamba_ssm
    Found existing installation: mamba-ssm 2.2.4
    Uninstalling mamba-ssm-2.2.4:
      Successfully uninstalled mamba-ssm-2.2.4
Successfully installed mamba_ssm-2.2.4


In [10]:
### Import necessary libraries ###
import os
import torch
import torch.nn.utils.prune as prune
import torch.nn as nn
import torch.optim as optim
import time
from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from torchprofile import profile_macs
from thop import profile
# Step 2: Set environment variable to try disabling Triton optimizations
os.environ["DISABLE_TRITON"] = "1"
# Load the pre-trained MAMBA model
model = MambaLMHeadModel.from_pretrained("state-spaces/mamba-2.8b-slimpj")

  return torch.load(resolved_archive_file, map_location=mapped_device)


In [11]:
### Set up the test dataset ###
# Initialize the tokenizer using GPT-2's tokenizer (since Mamba tokenizer is not available)
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # Using GPT-2 tokenizer as an alternative

# Add a pad token to the tokenizer (since GPT-2 doesn't have one by default)
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

# Step 3: Load the WikiText-2 dataset
# Load the WikiText-2 dataset
dataset = load_dataset("wikitext", "wikitext-2-v1")

# Only use the training split (subset of the data for simplicity)
texts = dataset['train']['text'][:1000]  # Use a smaller subset for quick testing or use the whole dataset

# Step 4: Define a custom PyTorch Dataset class
class WikiTextDataset(Dataset):
    def __init__(self, texts, tokenizer, max_length=128):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        return input_ids, attention_mask

# Step 5: Create the Dataset and DataLoader
wikitext_dataset = WikiTextDataset(texts, tokenizer)
test_loader = DataLoader(wikitext_dataset, batch_size=1, shuffle=False)

In [23]:
# Step 6: Evaluate the original Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)

# Set model to evaluation mode
model.eval()

# Initialize evaluation metrics
criterion = nn.CrossEntropyLoss()
total_loss = 0
total = 0
correct = 0
total_time = 0

# Calculate FLOPs using thop
input_sample = torch.randint(0, tokenizer.vocab_size, (1, 128)).to(device)
macs, params = profile(model, inputs=(input_sample,))
flops = 2 * macs  # FLOPs is typically 2 * MACs for neural networks

print(f"FLOPs: {flops:.2e}")
print(f"Number of Parameters: {params}")

# Disable gradient calculations during evaluation to save memory
with torch.no_grad():
    for batch_idx, (input_ids, attention_mask) in enumerate(test_loader):
        # Move the inputs to the GPU if available
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Measure the time taken for the forward pass (inference time)
        start_time = time.time()

        # Forward pass: compute predictions (without attention mask)
        outputs = model(input_ids)
        logits = outputs.logits  # Assuming the output contains logits

        end_time = time.time()
        total_time += (end_time - start_time)

        # Compute the loss
        loss = criterion(logits.view(-1, logits.size(-1)), input_ids.view(-1))
        total_loss += loss.item()

        # For testing accuracy, find the predicted tokens
        predictions = torch.argmax(logits, dim=-1)

        # Calculate correct predictions
        correct += (predictions == input_ids).sum().item()
        total += input_ids.numel()

        # Skip profiling of MACs for now due to Triton incompatibility issues

# Calculate average loss and accuracy for the test set
average_test_loss = total_loss / len(test_loader)
accuracy = 100 * correct / total
# average_inference_time = total_time / len(test_loader)

# Print the evaluation results
print(f"Test Loss: {average_test_loss:.4f}")
print(f"Test Accuracy: {accuracy:.2f}%")
# print(f"Average Inference Time per Batch: {average_inference_time:.4f} seconds")

[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
FLOPs: 2.64e+11
Number of Parameters: 1032130560.0
Test Loss: 4.2953
Test Accuracy: 63.47%


In [25]:
# Step 7: Prune the model using L1 unstructured pruning
# Make a copy of the model to create a pruned version
model_pruned = model

# Apply L1 unstructured pruning to Linear and Conv1d layers
pruned_layers_count = 0
for name, module in model_pruned.named_modules():
    # Only prune Linear and Conv1d layers
    if isinstance(module, (nn.Linear, nn.Conv1d)):
        prune.l1_unstructured(module, name='weight', amount=0.4)
        pruned_layers_count += 1

# Print the total number of layers that have been pruned
print(f"Total number of layers pruned: {pruned_layers_count}")

# Step 8: Make pruning permanent
for name, module in model_pruned.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv1d)):
        prune.remove(module, 'weight')

# Step 9: Save the pruned model
torch.save(model_pruned.state_dict(), 'model_pruned.pth')
print("Pruned model has been saved as 'model_pruned.pth'.")

# Step 10: Evaluate FLOPs Before and After Pruning
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_pruned.to(device)

# Prepare an input sample for FLOPs calculation
input_sample = torch.randint(0, tokenizer.vocab_size, (1, 128)).to(device)

# # Calculate FLOPs using thop for the original model
macs_original, params_original = profile(model, inputs=(input_sample,))
flops_original = 2 * macs_original  # FLOPs is typically 2 * MACs for neural networks

# Calculate FLOPs using thop for the pruned model
macs_pruned, params_pruned = profile(model_pruned, inputs=(input_sample,))
flops_pruned = 2 * macs_pruned  # FLOPs is typically 2 * MACs for neural networks

# Calculate Sparsity of the Pruned Model
total_params_pruned = 0
total_zero_params_pruned = 0

for name, module in model_pruned.named_modules():
    if hasattr(module, 'weight'):
        weight = module.weight.detach().cpu().numpy()
        total_params_pruned += weight.size
        total_zero_params_pruned += (weight == 0).sum()

sparsity = total_zero_params_pruned / total_params_pruned
# print(f"Pruned Model - Sparsity: {sparsity:.2%}")

# Calculate Effective FLOPs After Pruning
effective_flops = (1 - sparsity) * flops_original
print(f"Pruned Model - FLOPs: {effective_flops:.2e}")
print(f"Pruned Model - Number of Parameters: {params_pruned}")

# Step 11: Evaluate the Pruned Model
# Set pruned model to evaluation mode
model_pruned.eval()

# Initialize evaluation metrics for pruned model
criterion = nn.CrossEntropyLoss()
total_loss = 0
total = 0
correct = 0
total_time = 0

# Disable gradient calculations during evaluation to save memory
with torch.no_grad():
    for batch_idx, (input_ids, attention_mask) in enumerate(test_loader):
        # Move the inputs to the GPU if available
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # Measure the time taken for the forward pass (inference time)
        start_time = time.time()

        # Forward pass: compute predictions (without attention mask)
        outputs = model_pruned(input_ids)
        logits = outputs.logits  # Assuming the output contains logits

        end_time = time.time()
        total_time += (end_time - start_time)

        # Compute the loss
        loss = criterion(logits.view(-1, logits.size(-1)), input_ids.view(-1))
        total_loss += loss.item()

        # For testing accuracy, find the predicted tokens
        predictions = torch.argmax(logits, dim=-1)

        # Calculate correct predictions
        correct += (predictions == input_ids).sum().item()
        total += input_ids.numel()

# Calculate average loss and accuracy for the test set with pruned model
average_test_loss = total_loss / len(test_loader)
accuracy = 100 * correct / total
# average_inference_time = total_time / len(test_loader)

# Print the evaluation results for the pruned model
print(f"Pruned Model - Test Loss: {average_test_loss:.4f}")
print(f"Pruned Model - Test Accuracy: {accuracy:.2f}%")
# print(f"Pruned Model - Average Inference Time per Batch: {average_inference_time:.4f} seconds")

Total number of layers pruned: 321
Pruned model has been saved as 'model_pruned.pth'.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
Pruned Model - FLOPs: 1.58e+11
Pruned Model - Number of Parameters: 1032130560.0
Pruned Model - Test Loss: 4.2953
Pruned Model - Test Accuracy: 63.47%


In [16]:
print(model)

MambaLMHeadModel(
  (backbone): MixerModel(
    (embedding): Embedding(50280, 2560)
    (layers): ModuleList(
      (0-63): 64 x Block(
        (norm): RMSNorm()
        (mixer): Mamba(
          (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
          (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
          (act): SiLU()
          (x_proj): Linear(in_features=5120, out_features=192, bias=False)
          (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
          (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
        )
      )
    )
    (norm_f): RMSNorm()
  )
  (lm_head): Linear(in_features=2560, out_features=50280, bias=False)
)
