# Memory Efficient Linear Layer Implementation

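This notebook implements a memory-efficient linear layer: the input is processed in mini-batches, so the full output of the linear layer (e.g., the logits over the whole vocabulary) is never materialized at once. This reduces peak VRAM usage during both the forward and backward passes.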

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Callable, Tuple, Any
import math

In [None]:
def split_into_batches(tensor: torch.Tensor, batch_size: int):
    """Cut ``tensor`` along its first dimension into consecutive chunks.

    Args:
        tensor: Tensor to divide.
        batch_size: Number of rows per chunk; the final chunk is shorter
            when the first dimension is not a multiple of ``batch_size``.

    Returns:
        A tuple of chunks, in order, as produced by ``Tensor.split``.
    """
    chunks = tensor.split(batch_size)
    return chunks

In [None]:
class MemoryEfficientLinear(torch.autograd.Function):
    """Chunked evaluation of ``transform_fn(X, linear, labels)``.

    The input is split along dim 0 into chunks of ``batch_size`` rows and
    ``transform_fn`` is evaluated one chunk at a time, so the output of the
    linear layer only ever exists for a single chunk. The backward pass
    recomputes each chunk with gradient tracking instead of keeping the
    forward graph alive.

    ``transform_fn(batch, linear, batch_labels)`` must return a scalar loss
    that is a mean over the chunk; the per-chunk values are then combined
    into the mean over all rows of ``X``.

    Note on parameter gradients: ``linear`` is passed as a module rather than
    as tensors, so autograd cannot route gradients to its parameters through
    the values returned by ``backward``. They are therefore accumulated into
    each parameter's ``.grad`` as a side effect of ``backward``. ``backward``
    only runs when a tensor input (normally ``X``) requires grad.
    """

    @staticmethod
    def forward(ctx, X: torch.Tensor, linear: nn.Linear,
                labels: torch.Tensor, transform_fn: Callable,
                batch_size: int = 2) -> torch.Tensor:
        """
        Forward pass that processes data in batches to save memory.

        Args:
            X: Input tensor; chunked along dim 0.
            linear: Linear layer handed to ``transform_fn``.
            labels: Targets; chunked along dim 0 in lockstep with ``X``.
            transform_fn: Callable ``(batch, linear, batch_labels) -> loss``
                (e.g., cross entropy) returning a mean-reduced scalar.
            batch_size: Number of rows of ``X`` per chunk.

        Returns:
            The mean loss over all rows of ``X``.

        Raises:
            ValueError: If ``X`` is empty along dim 0, or ``X`` and
                ``labels`` differ in length along dim 0.
        """
        total_rows = len(X)
        if total_rows == 0:
            raise ValueError("X must contain at least one row along dim 0")
        if len(labels) != total_rows:
            # zip() below would silently drop the unmatched tail otherwise.
            raise ValueError(
                f"X and labels must have the same length along dim 0 "
                f"(got {total_rows} and {len(labels)})"
            )

        # Save everything backward needs to recompute each chunk.
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.transform_fn = transform_fn
        ctx.batch_size = batch_size

        # Autograd is disabled inside Function.forward, so no graph (and no
        # per-chunk activations) outlives an iteration of this loop.
        total_loss = 0
        for batch, batch_labels in zip(split_into_batches(X, batch_size),
                                       split_into_batches(labels, batch_size)):
            loss = transform_fn(batch, linear, batch_labels)
            # Weight each chunk mean by its row count so that a shorter
            # final chunk does not skew the overall mean.
            total_loss += loss * len(batch)

        return total_loss / total_rows

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> Tuple[Any, None, None, None, None]:
        """Backward pass that recomputes and differentiates one chunk at a time.

        Args:
            grad_output: Gradient of the final objective w.r.t. the scalar
                returned by ``forward``.

        Returns:
            A 5-tuple matching the inputs of ``forward``: the gradient for
            ``X`` (``None`` when ``X`` does not require grad) followed by
            ``None`` for the four remaining inputs. Gradients for the
            trainable parameters of ``linear`` are accumulated into their
            ``.grad`` attributes as a side effect.
        """
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        transform_fn = ctx.transform_fn
        batch_size = ctx.batch_size
        total_rows = len(X)

        need_input_grad = ctx.needs_input_grad[0]
        # Frozen parameters must not be handed to autograd.grad (it raises).
        params = [p for p in linear.parameters() if p.requires_grad]
        if not need_input_grad and not params:
            return None, None, None, None, None

        grad_X = torch.zeros_like(X) if need_input_grad else None

        offset = 0
        for batch, batch_labels in zip(split_into_batches(X, batch_size),
                                       split_into_batches(labels, batch_size)):
            rows = len(batch)

            # Fresh leaf sharing storage with the chunk, so the recomputed
            # graph covers only this chunk.
            batch_tensor = batch.detach()
            if need_input_grad:
                batch_tensor.requires_grad_()

            # Grad mode is off inside backward; re-enable it to rebuild the
            # chunk's graph.
            with torch.enable_grad():
                loss = transform_fn(batch_tensor, linear, batch_labels)

            # forward weighted this chunk's loss by rows / total_rows, so the
            # incoming gradient must be scaled by the same factor.
            chunk_grad_output = grad_output * (rows / total_rows)

            wrt = ([batch_tensor] if need_input_grad else []) + params
            grads = torch.autograd.grad(loss, wrt, chunk_grad_output,
                                        allow_unused=True)

            if need_input_grad:
                input_grad, param_grads = grads[0], grads[1:]
                if input_grad is not None:
                    grad_X[offset:offset + rows] = input_grad
            else:
                param_grads = grads

            # Accumulate parameter gradients the way autograd itself would.
            for param, param_grad in zip(params, param_grads):
                if param_grad is None:
                    continue  # parameter not used by transform_fn
                if param.grad is None:
                    param.grad = param_grad
                else:
                    param.grad.add_(param_grad)

            offset += rows

        return grad_X, None, None, None, None

In [None]:
# Example transformation functions
def cross_entropy_transform(batch: torch.Tensor, linear: nn.Linear, labels: torch.Tensor) -> torch.Tensor:
    """Project ``batch`` through ``linear`` and return the mean cross entropy.

    The projected scores are cast to float32, collapsed to two dimensions
    (rows x classes) and compared against ``labels`` flattened to one
    dimension.
    """
    scores = linear(batch).float()
    num_classes = scores.shape[-1]
    flat_scores = scores.view(-1, num_classes)
    flat_targets = labels.view(-1)
    # 'mean' is the default reduction of F.cross_entropy.
    return F.cross_entropy(flat_scores, flat_targets)

def mse_transform(batch: torch.Tensor, linear: nn.Linear, labels: torch.Tensor) -> torch.Tensor:
    """Project ``batch`` through ``linear`` and return the mean squared error.

    The projection is cast to float32 before being compared with ``labels``.
    """
    predictions = linear(batch).float()
    # 'mean' is the default reduction of F.mse_loss.
    return F.mse_loss(predictions, labels)

def custom_transform(batch: torch.Tensor, linear: nn.Linear, labels: torch.Tensor) -> torch.Tensor:
    """Example of a custom loss: half of the mean absolute (L1) error.

    The projection of ``batch`` through ``linear`` is cast to float32 before
    being compared with ``labels``.
    """
    predictions = linear(batch).float()
    # 'mean' is the default reduction of F.l1_loss.
    mean_abs_error = F.l1_loss(predictions, labels)
    return mean_abs_error * 0.5

In [None]:
# Test the implementation
def test_memory_efficient_linear():
    """Smoke-test ``MemoryEfficientLinear`` with each example transform.

    For every transform, runs a chunked forward and backward pass and prints
    the resulting loss and the norm of the input gradient. Uses CUDA when
    available and falls back to CPU otherwise.
    """
    # Parameters
    batch_size = 4
    seq_len = 512
    hidden_dim = 1024
    vocab_size = 32000  # Smaller for testing

    # Do not hard-code 'cuda': the smoke test should also run without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # X must require grad: it is the only differentiable tensor passed to
    # apply(), and without it the output has no grad_fn and backward() raises.
    X = torch.randn(batch_size, seq_len, hidden_dim, device=device, requires_grad=True)

    # Class indices for cross entropy; dense float targets matching the
    # layer's output shape for the regression-style losses (MSE / L1).
    class_labels = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    regression_targets = torch.randn(batch_size, seq_len, vocab_size, device=device)

    # Create linear layer
    linear = nn.Linear(hidden_dim, vocab_size, bias=True).to(device)

    # Each transformation is paired with targets of the kind it expects
    transforms = {
        'cross_entropy': (cross_entropy_transform, class_labels),
        'mse': (mse_transform, regression_targets),
        'custom': (custom_transform, regression_targets),
    }

    for name, (transform_fn, targets) in transforms.items():
        print(f"\nTesting {name} transformation:")

        # Start each transform from clean gradients so runs don't mix.
        X.grad = None
        linear.zero_grad()

        # Memory efficient forward pass
        efficient_output = MemoryEfficientLinear.apply(X, linear, targets, transform_fn, 2)

        # Compute gradients
        efficient_output.backward()

        print(f"Output shape: {efficient_output.shape}")
        print(f"Output value: {efficient_output.item():.4f}")
        print(f"Input grad norm: {X.grad.norm().item():.4f}")

# Run the smoke test only when this file is executed as a script, not on import.
if __name__ == "__main__":
    test_memory_efficient_linear()