# Streaming PyTorch Models

This example demonstrates a basic approach to loading model weights from files on disk at initialization time, rather than defining them as trainable parameters of the model. This specific implementation still copies the entire weights into memory when the model is created, so it does not yet reduce RAM usage, but it serves as a starting point for more advanced streaming techniques.

## Imports

In [None]:
%pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu
%pip install numpy==1.24.4

In [None]:
import torch
import torch.nn as nn
import numpy as np
import os


## Simple streaming model definition

The StreamedLinear class is a custom PyTorch module that mimics a linear layer but loads its weights from files:

* It uses np.memmap to memory-map the weight and bias files. A memory map by itself gives access to the data on disk without reading the whole file into memory.
* The mapped arrays are then copied into PyTorch tensors and registered as buffers using register_buffer. The copy brings the full weights into RAM; registering them as buffers makes them part of the module's state (including, by default, its state_dict) without being considered parameters.
* The forward method performs a linear operation using these loaded weights and biases.
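
The difference between buffers and parameters can be checked on a toy module. The sketch below uses a hypothetical BufferDemo class that is not part of this example; it only illustrates where each kind of tensor shows up.

```python
import torch
import torch.nn as nn


class BufferDemo(nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self):
        super().__init__()
        # A buffer: stored in state_dict and moved by .to(), but never trained
        self.register_buffer("frozen", torch.ones(3))
        # A parameter: stored in state_dict and returned by parameters()
        self.trainable = nn.Parameter(torch.zeros(3))


demo = BufferDemo()
param_names = [name for name, _ in demo.named_parameters()]
buffer_names = [name for name, _ in demo.named_buffers()]
state_keys = sorted(demo.state_dict().keys())

print(param_names)   # ['trainable']
print(buffer_names)  # ['frozen']
print(state_keys)    # ['frozen', 'trainable']
```

Because buffers are excluded from parameters(), an optimizer built from model.parameters() will not update them.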

In [None]:
class StreamedLinear(nn.Module):
    def __init__(self, in_features, out_features, weight_file, bias_file):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        
        # Open memory-mapped files
        self.weight_mmap = np.memmap(weight_file, dtype='float32', mode='r', shape=(out_features, in_features))
        self.bias_mmap = np.memmap(bias_file, dtype='float32', mode='r', shape=(out_features,))

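        # Copy the mapped data into in-memory tensors and register them as buffers.
        # np.array makes a writable copy, which avoids PyTorch's warning about
        # wrapping a read-only array.
        self.register_buffer("weight", torch.from_numpy(np.array(self.weight_mmap)))
        self.register_buffer("bias", torch.from_numpy(np.array(self.bias_mmap)))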

    def forward(self, input):
        return nn.functional.linear(input, self.weight, self.bias)
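
The StreamedLinear layer above copies its weights into RAM when it is constructed. A lazier variant could keep only the memory map and wrap it in a tensor on each forward pass, leaving it to the operating system to page data in on access. The following is a minimal sketch of that idea, not a tuned implementation: LazyStreamedLinear is a hypothetical name, and mode="c" (copy-on-write) is an assumption chosen so that the mapped array is writable in memory while the file on disk stays untouched.

```python
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


class LazyStreamedLinear(nn.Module):
    """Hypothetical variant: keeps the memmap and wraps it on each forward pass."""

    def __init__(self, in_features, out_features, weight_file, bias_file):
        super().__init__()
        # mode="c" is copy-on-write: writable in memory, never written back to disk
        self.weight_mmap = np.memmap(weight_file, dtype="float32", mode="c",
                                     shape=(out_features, in_features))
        self.bias_mmap = np.memmap(bias_file, dtype="float32", mode="c",
                                   shape=(out_features,))

    def forward(self, x):
        # from_numpy shares memory with the mapping instead of copying it
        weight = torch.from_numpy(self.weight_mmap)
        bias = torch.from_numpy(self.bias_mmap)
        return nn.functional.linear(x, weight, bias)


# Usage: write a reference layer to disk, then compare outputs.
tmp_dir = tempfile.mkdtemp()
w_path = os.path.join(tmp_dir, "weight.bin")
b_path = os.path.join(tmp_dir, "bias.bin")

reference = nn.Linear(10, 5)
reference.weight.detach().numpy().tofile(w_path)
reference.bias.detach().numpy().tofile(b_path)

lazy = LazyStreamedLinear(10, 5, w_path, b_path)
x = torch.randn(2, 10)
with torch.no_grad():
    lazy_matches = torch.allclose(lazy(x), reference(x))

print("Lazy output matches:", lazy_matches)
print("Registered buffers:", len(list(lazy.buffers())))
```

The trade-off is that the weights no longer appear in state_dict and are not moved by .to(device), so this variant only makes sense for CPU inference.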


The StreamedModel class is a simple wrapper around StreamedLinear, representing a model with a single linear layer using streamed weights.

In [None]:
class StreamedModel(nn.Module):
    def __init__(self, in_features, out_features, weight_file, bias_file):
        super().__init__()
        self.linear = StreamedLinear(in_features, out_features, weight_file, bias_file)

    def forward(self, x):
        return self.linear(x)

## Utility Functions and Setup

This section defines a function to save weights and biases to files and sets up a directory to store these files.

In [None]:
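# Function to save weights to file as raw float32 data (no shape or dtype header)
def save_weights_to_file(weight, bias, weight_file, bias_file):
    # Move to CPU and cast to float32 so the files match the dtype used by np.memmap
    weight.detach().cpu().numpy().astype(np.float32).tofile(weight_file)
    bias.detach().cpu().numpy().astype(np.float32).tofile(bias_file)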

In [None]:
# Create a directory for weight files
os.makedirs('streaming', exist_ok=True)
weight_file = 'streaming/weight.bin'
bias_file = 'streaming/bias.bin'

## Model Creation and Weight Saving

In [None]:
# Create and save a regular model
in_features, out_features = 10, 5
regular_model = nn.Linear(in_features, out_features)
save_weights_to_file(regular_model.weight, regular_model.bias, weight_file, bias_file)


In [None]:
# Create a streamed model
streamed_model = StreamedModel(in_features, out_features, weight_file, bias_file)

## Evaluation

A fully memory-mapped version of this approach, one that does not copy the weights at initialization, has several potential advantages:

* Memory Efficiency: The full weight tensors need not be loaded into RAM; the operating system pages in only the parts touched by the current computation.
* Disk I/O Optimization: Memory-mapped files let the operating system handle paging and caching, which supports random access to individual weights.
* Scalability: This method can be extended to very large models that wouldn't fit in RAM.

However, there are also some considerations:

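* Performance: Reading weights from disk during a forward pass is slower than reading them from RAM. In this example the mmap arrays are converted to PyTorch tensors once, at initialization, so the forward pass itself carries no conversion overhead.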
* Complexity: This approach adds complexity to your model implementation.
* Storage Format: You need to ensure your weights are stored in a compatible format on disk.
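
On the storage format point: files written with tofile are raw bytes with no shape or dtype header, so a file saved with a different dtype but the same byte count can be memory-mapped without any error. One possible safeguard is to compare the file size with the size implied by the expected shape before mapping. The helpers below (expected_num_bytes and check_raw_file) are hypothetical names introduced for this sketch.

```python
import os
import tempfile

import numpy as np


def expected_num_bytes(shape, dtype="float32"):
    """Bytes a raw array of this shape and dtype occupies (hypothetical helper)."""
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


def check_raw_file(path, shape, dtype="float32"):
    """Raise ValueError if the file size does not match shape and dtype."""
    actual = os.path.getsize(path)
    expected = expected_num_bytes(shape, dtype)
    if actual != expected:
        raise ValueError(f"{path}: expected {expected} bytes, found {actual}")


# Usage: a 5x10 float32 matrix occupies 5 * 10 * 4 = 200 bytes.
demo_path = os.path.join(tempfile.mkdtemp(), "weight.bin")
np.zeros((5, 10), dtype="float32").tofile(demo_path)

check_raw_file(demo_path, (5, 10))      # correct shape: passes silently
try:
    check_raw_file(demo_path, (5, 11))  # wrong shape: raises
    mismatch_detected = False
except ValueError:
    mismatch_detected = True

print("Mismatch detected:", mismatch_detected)
```

A size check cannot catch every mismatch (a transposed shape has the same byte count), so a format that stores shape and dtype alongside the data may be a more robust choice.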

In [None]:
# Run inference
input_data = torch.randn(1, in_features)
output_regular = regular_model(input_data)
output_streamed = streamed_model(input_data)

print("Regular model output:", output_regular)
print("Streamed model output:", output_streamed)
print("Outputs are close:", torch.allclose(output_regular, output_streamed))

# Print number of parameters (the streamed model reports 0 because its weights are buffers)
num_params_regular = sum(p.numel() for p in regular_model.parameters())
num_params_streamed = sum(p.numel() for p in streamed_model.parameters())
print(f"Number of parameters in regular model: {num_params_regular}")
print(f"Number of parameters in streamed model: {num_params_streamed}")

# Verify that the weights are the same
print("Weights are close:", torch.allclose(regular_model.weight, streamed_model.linear.weight))
print("Biases are close:", torch.allclose(regular_model.bias, streamed_model.linear.bias))

For very large language models, you might need to extend this concept further:

* Implement caching mechanisms to keep frequently used weights in memory.
* Use more sophisticated sharding strategies to distribute weights across multiple files or even multiple machines.
* Optimize the storage format for quick access to specific layers or attention heads.
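
The first two ideas can be combined in a small sketch: the weight matrix is split row-wise across several files, and a cache keeps recently used shards in memory. Everything here is an assumption made for illustration, including the load_shard and sharded_linear names, the shard size of two rows, and the cache size of two entries.

```python
import functools
import os
import tempfile

import numpy as np
import torch
import torch.nn as nn


@functools.lru_cache(maxsize=2)  # keep at most two shards resident (arbitrary choice)
def load_shard(path, rows, cols):
    """Read one row-shard of the weight matrix from disk as a tensor."""
    data = np.fromfile(path, dtype="float32").reshape(rows, cols)
    return torch.from_numpy(data)


def sharded_linear(x, shard_specs, bias):
    """Compute x @ W.T + bias with W split row-wise across several files."""
    outputs = [x @ load_shard(path, rows, cols).T for path, rows, cols in shard_specs]
    return torch.cat(outputs, dim=-1) + bias


# Usage: split a reference layer's weight into shards of at most 2 rows each.
reference = nn.Linear(10, 5)
weight = reference.weight.detach().numpy()
tmp_dir = tempfile.mkdtemp()
shard_specs = []
for i, start in enumerate(range(0, weight.shape[0], 2)):
    block = weight[start:start + 2]
    path = os.path.join(tmp_dir, f"weight_shard_{i}.bin")
    block.tofile(path)
    shard_specs.append((path, block.shape[0], block.shape[1]))

x = torch.randn(3, 10)
with torch.no_grad():
    sharded_out = sharded_linear(x, shard_specs, reference.bias)
    sharded_matches = torch.allclose(sharded_out, reference(x), atol=1e-5)

print("Number of shards:", len(shard_specs))
print("Sharded output matches:", sharded_matches)
```

Each shard produces a slice of the output features, so only one shard's weights need to be in memory at a time when the cache is small. The cached tensors are shared between calls, so callers should treat them as read-only.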