# Streaming DistilBERT

This example demonstrates a basic approach to loading model weights from raw binary files at initialization time, rather than from a single serialized state_dict checkpoint. Some notes on this implementation:

* Only the weights of the final classification layer are read from raw files. The DistilBERT base model weights are still saved and loaded conventionally. A full implementation would stream the weights of every layer.
* The classification layer weights are copied into memory in full at initialization. Truly large models may need a more sophisticated streaming mechanism that loads weights in smaller chunks or on demand.

## Imports

In [12]:
%pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
%pip install transformers==4.41.2
%pip install numpy==1.24.4

Looking in indexes: https://download.pytorch.org/whl/cpu
Note: you may need to restart the kernel to use updated packages.


In [13]:
import torch
import torch.nn as nn
from transformers import DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer
import numpy as np
import os
import time

## Simple streaming model definition

The StreamedLinear class represents a linear layer with weights loaded from files:

* It uses np.memmap to memory-map the weight and bias files.
* It determines the shape of the weight matrix based on the total number of elements and the number of output features (which is the same as the number of bias elements).
* The weights and biases are copied into PyTorch tensors and registered as buffers, so they are held in memory after initialization and are not counted as trainable parameters.
* The forward method performs a linear operation using these loaded weights and biases.

In [14]:
class StreamedLinear(nn.Module):
    def __init__(self, weight_file, bias_file):
        super().__init__()
        self.weight_mmap = np.memmap(weight_file, dtype='float32', mode='r')
        self.bias_mmap = np.memmap(bias_file, dtype='float32', mode='r')

        # Determine shape from the size of the memory-mapped array
        # This approach calculates the input features based on the total number of elements in the weight array 
        # and the number of output features (which is the same as the number of bias elements).
        total_elements = self.weight_mmap.size
        self.out_features = self.bias_mmap.size
        self.in_features = total_elements // self.out_features

        # Reshape the weight array
        self.weight_mmap = self.weight_mmap.reshape(self.out_features, self.in_features)

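        # Copy the read-only mapped data into writable in-memory arrays and register
        # them as buffers. Copying with np.array avoids the warning that
        # torch.from_numpy emits for non-writable arrays.
        self.register_buffer("weight", torch.from_numpy(np.array(self.weight_mmap)))
        self.register_buffer("bias", torch.from_numpy(np.array(self.bias_mmap)))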

    def forward(self, input):
        return nn.functional.linear(input, self.weight, self.bias)
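The shape inference in StreamedLinear uses integer division, so a weight file whose element count is not a multiple of the bias length only fails later, at the reshape, with a generic error. A minimal sketch of the same inference with an explicit check, using only NumPy; infer_linear_shape and the temporary files are hypothetical and not part of the notebook:

```python
import os
import tempfile

import numpy as np


def infer_linear_shape(weight_file, bias_file, dtype="float32"):
    """Return (out_features, in_features) for flat weight and bias files."""
    weight = np.memmap(weight_file, dtype=dtype, mode="r")
    bias = np.memmap(bias_file, dtype=dtype, mode="r")
    out_features = bias.size
    if weight.size % out_features != 0:
        raise ValueError(
            f"{weight.size} weight elements cannot be split into "
            f"{out_features} rows"
        )
    return out_features, weight.size // out_features


# Round trip: write a (2, 768) weight and a (2,) bias, then recover the shape.
with tempfile.TemporaryDirectory() as tmp:
    w_path = os.path.join(tmp, "classifier.weight.bin")
    b_path = os.path.join(tmp, "classifier.bias.bin")
    np.zeros((2, 768), dtype="float32").tofile(w_path)
    np.zeros(2, dtype="float32").tofile(b_path)
    shape = infer_linear_shape(w_path, b_path)

print(shape)
```

Raw files written with tofile carry no shape or dtype information, so the reader has to assume both; here the dtype is assumed to be float32, as in the notebook.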


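The StreamedDistilBertForSequenceClassification class represents the DistilBERT model for sequence classification:

* It initializes the DistilBERT base model normally.
* It replaces the classification layer with our StreamedLinear layer.
* The forward method passes the input through the DistilBERT base, takes the hidden state of the first token, and feeds it directly to the streamed classification layer. Unlike DistilBertForSequenceClassification, it does not apply the pre_classifier layer, ReLU, or dropout that sit between the base model and the classifier.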

In [15]:
class StreamedDistilBertForSequenceClassification(nn.Module):
    def __init__(self, config, weight_dir):
        super().__init__()
        self.config = config
        self.weight_dir = weight_dir
        self.num_labels = config.num_labels

        # Initialize DistilBERT layers (excluding final classification layer)
        self.distilbert = DistilBertForSequenceClassification(config).distilbert

        # Replace the classification layer with StreamedLinear
        weight_file = os.path.join(weight_dir, 'classifier.weight.bin')
        bias_file = os.path.join(weight_dir, 'classifier.bias.bin')
        self.classifier = StreamedLinear(weight_file, bias_file)

    def forward(self, input_ids, attention_mask=None):
        outputs = self.distilbert(input_ids, attention_mask=attention_mask)
        hidden_state = outputs[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim)
        logits = self.classifier(pooled_output)  # (bs, num_labels)
        return logits
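For comparison, a head that follows the same sequence of steps as DistilBertForSequenceClassification (pre_classifier, ReLU, dropout, classifier) could look like the sketch below. It uses plain nn.Linear layers rather than streamed ones; ClassificationHead and its argument names are hypothetical, and dropout=0.2 is assumed to match the default seq_classif_dropout in DistilBertConfig:

```python
import torch
import torch.nn as nn


class ClassificationHead(nn.Module):
    """pre_classifier -> ReLU -> dropout -> classifier on the first token."""

    def __init__(self, dim=768, num_labels=2, dropout=0.2):
        super().__init__()
        self.pre_classifier = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(dim, num_labels)

    def forward(self, hidden_state):
        pooled = hidden_state[:, 0]                       # (bs, dim)
        pooled = torch.relu(self.pre_classifier(pooled))  # (bs, dim)
        pooled = self.dropout(pooled)                     # no-op in eval mode
        return self.classifier(pooled)                    # (bs, num_labels)


head = ClassificationHead().eval()
with torch.no_grad():
    logits = head(torch.randn(4, 16, 768))  # (bs=4, seq_len=16, dim=768)
print(logits.shape)
```

With the default sizes this head holds 592,130 parameters: 590,592 in pre_classifier and 1,538 in classifier.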

## Utility Functions and Setup

These functions handle saving and loading model weights:

* save_model_weights saves the classification layer weights to raw binary files and the DistilBERT base weights to a PyTorch state dict file. The pre_classifier weights of the regular model are not saved.
* load_streamed_model creates a new streamed model and loads the DistilBERT base weights.

In [16]:
def save_model_weights(model, weight_dir):
    os.makedirs(weight_dir, exist_ok=True)
    
    # Save weights for the classification layer
    # .cpu() keeps this working if the model has been moved to another device
    classifier_weight = model.classifier.weight.detach().cpu().numpy()
    classifier_bias = model.classifier.bias.detach().cpu().numpy()
    
    classifier_weight.tofile(os.path.join(weight_dir, 'classifier.weight.bin'))
    classifier_bias.tofile(os.path.join(weight_dir, 'classifier.bias.bin'))

    # Save other layers (in practice, you'd do this for all layers)
    torch.save(model.distilbert.state_dict(), os.path.join(weight_dir, 'distilbert_weights.pth'))

In [17]:
def load_streamed_model(config, weight_dir):
    model = StreamedDistilBertForSequenceClassification(config, weight_dir)
    # Load weights for other layers
    model.distilbert.load_state_dict(torch.load(os.path.join(weight_dir, 'distilbert_weights.pth')))
    return model
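Extending the raw-file approach to every layer needs some record of each tensor's shape and dtype, since one bias vector can no longer describe them all. One possible layout, sketched here with NumPy arrays only, writes one .bin file per tensor plus a JSON manifest. The functions save_tensors and open_tensor and the manifest.json file name are hypothetical; for a PyTorch model the input dictionary could be built from model.state_dict() by converting each tensor with .detach().cpu().numpy():

```python
import json
import os
import tempfile

import numpy as np


def save_tensors(tensors, weight_dir):
    """Write each array to <name>.bin and record shape/dtype in manifest.json."""
    os.makedirs(weight_dir, exist_ok=True)
    manifest = {}
    for name, array in tensors.items():
        array = np.ascontiguousarray(array)
        array.tofile(os.path.join(weight_dir, f"{name}.bin"))
        manifest[name] = {"shape": list(array.shape), "dtype": str(array.dtype)}
    with open(os.path.join(weight_dir, "manifest.json"), "w") as f:
        json.dump(manifest, f)


def open_tensor(weight_dir, name):
    """Memory-map one tensor using the shape and dtype from the manifest."""
    with open(os.path.join(weight_dir, "manifest.json")) as f:
        entry = json.load(f)[name]
    return np.memmap(
        os.path.join(weight_dir, f"{name}.bin"),
        dtype=entry["dtype"],
        mode="r",
        shape=tuple(entry["shape"]),
    )


with tempfile.TemporaryDirectory() as tmp:
    weights = {"classifier.weight": np.arange(6, dtype="float32").reshape(2, 3)}
    save_tensors(weights, tmp)
    # Copy the data out of the mapping before the directory is removed.
    restored = np.array(open_tensor(tmp, "classifier.weight"))

print(restored.shape)
```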

## Model Creation and Weight Saving

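First, this code creates a regular DistilBERT model from the configuration, so its weights are randomly initialized rather than pretrained, and saves them. Next, it loads a streamed version of the model. Finally, it creates a tokenizer for inference. The reported loading times measure model construction for the regular model, and construction plus reading the saved weights for the streamed model.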

In [18]:
# Example usage
config = DistilBertConfig.from_pretrained('distilbert-base-uncased')
config.num_labels = 2  # Binary classification

# Create and save a regular model
start_time = time.time()
regular_model = DistilBertForSequenceClassification(config)
regular_load_time = time.time() - start_time
print(f"  Regular model loading time: {regular_load_time:.4f} seconds")
save_model_weights(regular_model, 'weight_dir')

# Load the streamed model
start_time = time.time()
streamed_model = load_streamed_model(config, 'weight_dir')
streamed_load_time = time.time() - start_time
print(f"  Stream model loading time: {streamed_load_time:.4f} seconds")

# Initialize tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')



  Regular model loading time: 2.3865 seconds
  Stream model loading time: 2.5489 seconds


## Evaluation

This code runs inference on both the regular model and the streamed model and compares their outputs.

In [19]:
# Tokenize input
text = "This is an example sentence for inference."
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

# Run inference
with torch.no_grad():
    start_time = time.time()
    outputs_regular = regular_model(**inputs).logits
    regular_inference_time = time.time() - start_time
    start_time = time.time()
    outputs_streamed = streamed_model(**inputs)
    streamed_inference_time = time.time() - start_time

# Print results
print("Regular model output:", outputs_regular)
print(f"  Inference time: {regular_inference_time:.4f} seconds")
print(f"  Total time: {regular_load_time + regular_inference_time:.4f} seconds")
print("Streamed model output:", outputs_streamed)
print(f"  Inference time: {streamed_inference_time:.4f} seconds")
print(f"  Total time: {streamed_load_time + streamed_inference_time:.4f} seconds")
print("Outputs are close:", torch.allclose(outputs_regular, outputs_streamed, atol=1e-5))

# Print number of parameters
num_params_regular = sum(p.numel() for p in regular_model.parameters() if p.requires_grad)
num_params_streamed = sum(p.numel() for p in streamed_model.parameters() if p.requires_grad)
print(f"Number of trainable parameters in regular model: {num_params_regular:,}")
print(f"Number of trainable parameters in streamed model: {num_params_streamed:,}")

Regular model output: tensor([[0.4382, 0.1866]])
  Inference time: 0.1126 seconds
  Total time: 2.4990 seconds
Streamed model output: tensor([[-0.1210,  0.1469]])
  Inference time: 0.1608 seconds
  Total time: 2.7098 seconds
Outputs are close: False
Number of trainable parameters in regular model: 66,955,010
Number of trainable parameters in streamed model: 66,362,880
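Modules built directly from their constructor start in training mode, where dropout is active, so two forward passes on the same input can differ. Calling eval() before comparing outputs removes that source of randomness. A minimal sketch with a toy layer rather than the notebook's models:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Sequential(nn.Linear(8, 64), nn.Dropout(0.5))
x = torch.randn(1, 8)

with torch.no_grad():
    # Training mode (the default after construction): dropout masks are random.
    train_same = torch.equal(layer(x), layer(x))

    # Evaluation mode: dropout is a no-op, so repeated calls agree.
    layer.eval()
    eval_same = torch.equal(layer(x), layer(x))

print("train mode identical:", train_same)
print("eval mode identical:", eval_same)
```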


Key points about this implementation:

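* Only the weights of the final classification layer are read from raw files. The DistilBERT base model weights are still loaded conventionally.
* The streamed weights are copied into memory in full at initialization. Truly large models need a more sophisticated streaming mechanism.
* This approach demonstrates the concept of separating weight storage from the model architecture, which can be extended to larger models.
* The two outputs above do not match. The streamed model feeds the first token's hidden state straight to the classifier, skipping the pre_classifier layer, ReLU, and dropout that DistilBertForSequenceClassification applies. Both models are also left in training mode, so dropout is active during the comparison.
* The streamed model reports 592,130 fewer trainable parameters. Of these, 1,538 belong to the classifier, whose weights are now buffers rather than parameters, and 590,592 belong to the pre_classifier layer, which the streamed model does not include.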
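The effect of register_buffer on parameter counts can be seen in isolation. The sketch below compares a standard nn.Linear with a hypothetical BufferLinear that stores the same tensors as buffers:

```python
import torch
import torch.nn as nn


class BufferLinear(nn.Module):
    """Linear layer whose weights are buffers rather than nn.Parameter objects."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.register_buffer("weight", torch.zeros(out_features, in_features))
        self.register_buffer("bias", torch.zeros(out_features))

    def forward(self, x):
        return nn.functional.linear(x, self.weight, self.bias)


regular = nn.Linear(768, 2)
buffered = BufferLinear(768, 2)

n_regular = sum(p.numel() for p in regular.parameters())
n_buffered_params = sum(p.numel() for p in buffered.parameters())
n_buffered_buffers = sum(b.numel() for b in buffered.buffers())
print(n_regular, n_buffered_params, n_buffered_buffers)
```

The regular layer reports 1,538 parameters. The buffered layer reports none, and its 1,538 values show up under buffers() instead.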

Limitations and potential improvements:

* Extend streaming to all layers of the model, not just the classification layer.
* Implement true streaming where weights are loaded on-demand during forward passes.
* Add functionality to stream weights for fine-tuning or updating the model.
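As an illustration of on-demand loading, the linear operation can be computed while copying only a slice of the weight matrix from the memory-mapped file at a time. This is a minimal NumPy sketch under simplifying assumptions, not the notebook's implementation; chunked_linear and chunk_rows are hypothetical names:

```python
import os
import tempfile

import numpy as np


def chunked_linear(x, weight_mmap, bias, chunk_rows=256):
    """Compute x @ W.T + b while reading W chunk_rows rows at a time."""
    out_features = weight_mmap.shape[0]
    out = np.empty((x.shape[0], out_features), dtype=x.dtype)
    for start in range(0, out_features, chunk_rows):
        stop = min(start + chunk_rows, out_features)
        # Copy only this slice of rows from the mapped file into memory.
        w_chunk = np.array(weight_mmap[start:stop])
        out[:, start:stop] = x @ w_chunk.T + bias[start:stop]
    return out


# Small integer-valued float32 data keeps every sum exact, so the comparison
# below does not depend on floating-point rounding.
rng = np.random.default_rng(0)
weight = rng.integers(-3, 4, size=(1000, 64)).astype("float32")
bias = rng.integers(-3, 4, size=1000).astype("float32")
x = rng.integers(-3, 4, size=(4, 64)).astype("float32")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "layer.weight.bin")
    weight.tofile(path)
    weight_mmap = np.memmap(path, dtype="float32", mode="r", shape=weight.shape)
    streamed = chunked_linear(x, weight_mmap, bias)
    del weight_mmap  # release the mapping before the directory is removed

matches = np.array_equal(streamed, x @ weight.T + bias)
print(matches)
```

Peak memory for the weights is bounded by chunk_rows rows rather than the full matrix, at the cost of reading from disk on every forward pass.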

This example serves as a starting point for implementing weight streaming in neural network models, which can be particularly useful for deploying large language models in memory-constrained environments.