# Lazy-loading PyTorch Models

This notebook is a simple PyTorch example showing how to keep a model's architecture (defined in code) separate from its weights (stored on disk) and load the weights only when they are needed. It uses a small feed-forward network, but the same idea extends to large language models.

## Imports

In [None]:
%pip install torch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cpu

In [None]:
import torch
import torch.nn as nn
import os


## Create a simple model

In [None]:
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)  # 10 input features -> 5 hidden units
        self.fc2 = nn.Linear(5, 2)   # 5 hidden units -> 2 outputs

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

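The two helpers below separate architecture from weights: `save_model_weights` writes only the model's `state_dict` (a mapping from parameter names to tensors), and `load_model_weights` copies those tensors back into a freshly constructed instance of the same class.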

In [None]:
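def save_model_weights(model, path):
    # Save only the learned parameters (the state_dict), not the class definition.
    directory = os.path.dirname(path)
    if directory:  # dirname is "" for a bare filename, and os.makedirs("") would fail
        os.makedirs(directory, exist_ok=True)
    torch.save(model.state_dict(), path)

def load_model_weights(model, path):
    # Copy saved tensors into an existing instance of the same architecture.
    # weights_only=True restricts unpickling to tensors and simple containers.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    return model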

In [None]:
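# Create the model and save its weights. The weights here are just the random
# initialization; in practice you would train the model before saving.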
model = SimpleModel()
save_model_weights(model, "weights/model_weights.pth")
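To see what the separation means concretely, the following sketch saves a fresh `SimpleModel` and inspects the checkpoint: it holds only named tensors, so the class definition must be available in code whenever the weights are loaded. The cell repeats the class so it runs on its own, and it writes to a temporary directory (an assumption made for the sketch, instead of `weights/`).

```python
import os
import tempfile

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    # Same architecture as in the notebook, repeated so this cell runs on its own.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


# Temporary location so this sketch does not touch weights/.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "model_weights.pth")
torch.save(SimpleModel().state_dict(), path)

# The checkpoint is an ordered mapping of parameter names to tensors;
# nothing about the SimpleModel class itself is stored.
state_dict = torch.load(path, map_location="cpu", weights_only=True)
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))
# fc1.weight (5, 10)
# fc1.bias (5,)
# fc2.weight (2, 5)
# fc2.bias (2,)
```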

## Loading weights

The `get_model()` function illustrates the lazy-loading idea:

* It first creates the model architecture (which is fast and lightweight).
* It reads the weights from storage only when `get_model()` is actually called, not when the function is defined.

In [None]:
# Later, when you need to use the model:
def get_model():
    # Create the model architecture (this is fast and lightweight)
    model = SimpleModel()
    
    # Load weights only when needed (this is the potentially slow part)
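    model = load_model_weights(model, "weights/model_weights.pth")
    # Switch to inference mode (this matters for dropout or batch-norm layers;
    # SimpleModel has neither, but it is good practice)
    return model.eval()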


In a real-world scenario with large language models, you might extend this approach by:

* Memory-mapping weight files so tensor data is read from disk on demand instead of all at once.
* Loading only the weights for the layers or components you need.
* Loading weights asynchronously so other work can proceed while they arrive.
* Caching frequently used weights in memory.

In [None]:
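# Use the model for inference
loaded_model = get_model()
input_data = torch.randn(1, 10)  # one sample with 10 features
with torch.no_grad():  # gradients are not needed for inference
    output = loaded_model(input_data)
print(output)  # tensor of shape (1, 2)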
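Asynchronous loading and caching can be sketched together with a standard-library thread pool: loading starts as soon as the loader is created, and the resolved future doubles as a one-entry cache. `BackgroundModelLoader` and `load_model` are hypothetical names, and how much real overlap you get depends on how much of `torch.load` runs without holding Python's GIL.

```python
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor

import torch
import torch.nn as nn


class SimpleModel(nn.Module):
    # Same architecture as in the notebook, repeated so this cell runs on its own.
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))


def load_model(path):
    model = SimpleModel()
    model.load_state_dict(torch.load(path, map_location="cpu", weights_only=True))
    return model.eval()


class BackgroundModelLoader:
    """Hypothetical helper: load weights on a worker thread and keep the
    resulting model for every later call (a one-entry in-memory cache)."""

    def __init__(self, path):
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._future = self._executor.submit(load_model, path)

    def get(self):
        # Blocks only until the first load finishes; afterwards it returns
        # the same cached model object without touching the disk again.
        model = self._future.result()
        # The single background job is done, so release the worker thread.
        self._executor.shutdown(wait=False)
        return model


# Temporary location so this sketch does not touch weights/.
tmp_dir = tempfile.mkdtemp()
path = os.path.join(tmp_dir, "model_weights.pth")
torch.save(SimpleModel().state_dict(), path)

loader = BackgroundModelLoader(path)  # loading starts here
input_data = torch.randn(1, 10)       # other work can overlap with the load
with torch.no_grad():
    output = loader.get()(input_data)
print(output.shape)                   # torch.Size([1, 2])
print(loader.get() is loader.get())   # True
```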