In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
from torch.profiler import profile, record_function, ProfilerActivity
from torchvision.models import resnet18, ResNet18_Weights  # Example pre-trained model
from torch.utils.tensorboard import SummaryWriter
from collections import OrderedDict


In [6]:
import mlx.core as mx
import mlx.nn as nn

class CachedConv2D(torch.nn.Module):
    """Drop-in ``Conv2d`` wrapper that memoises the result for the last input.

    The layer behaves exactly like the wrapped ``torch.nn.Conv2d``. When it is
    called again with an input that is element-wise identical to the previous
    one, the stored output is returned instead of re-running the convolution.

    Caching is only active for inference, i.e. when the module is in eval mode
    and the output would not take part in autograd. In every other situation
    the call is forwarded to the convolution and the cache is dropped, so
    gradients always flow normally during training.

    ``torch.nn`` is referenced explicitly because this notebook also imports
    ``mlx.nn`` under the name ``nn``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size, forwarded to ``torch.nn.Conv2d``.
        stride: Convolution stride, forwarded to ``torch.nn.Conv2d``.
        padding: Input padding, forwarded to ``torch.nn.Conv2d``.

    Attributes:
        conv: The wrapped convolution; may be replaced by an existing layer.
        cache: ``None`` or an ``(input, output)`` tuple of detached copies.

    Note:
        The cache is keyed on the input only. Call ``clear_cache()`` after
        modifying the weights of ``conv`` while staying in eval mode.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # Plain attribute (not a buffer): it must never end up in state_dict.
        self.cache = None

    def clear_cache(self):
        """Forget the memoised input/output pair."""
        self.cache = None

    def _output_needs_grad(self, x):
        """Return True when the convolution output would be tracked by autograd."""
        if not torch.is_grad_enabled():
            return False
        return x.requires_grad or any(p.requires_grad for p in self.conv.parameters())

    def forward(self, x):
        """Apply the convolution, reusing the cached output for a repeated input.

        Args:
            x: Input tensor accepted by the wrapped ``torch.nn.Conv2d``.

        Returns:
            The convolution output for ``x``.
        """
        # A detached cached tensor would silently cut the autograd graph, and
        # weights are expected to change while training, so bypass the cache.
        if self.training or self._output_needs_grad(x):
            self.cache = None
            return self.conv(x)

        if self.cache is not None:
            cached_input, cached_output = self.cache
            # Cheap metadata checks first; torch.equal requires matching devices.
            is_hit = (
                cached_input.shape == x.shape
                and cached_input.dtype == x.dtype
                and cached_input.device == x.device
                and torch.equal(cached_input, x)
            )
            if is_hit:
                # Hand out a copy so in-place ops downstream cannot corrupt the cache.
                return cached_output.clone()

        output = self.conv(x)
        # Store private copies: the caller may later mutate `x` or `output` in place.
        self.cache = (x.detach().clone(), output.detach().clone())
        return output


In [7]:
def replace_convolution_layers(model):
    """Wrap every ``torch.nn.Conv2d`` inside ``model`` in a ``CachedConv2D``.

    The replacement happens in place and recursively over all sub-modules.
    Each existing convolution is reused as the wrapper's ``conv`` attribute, so
    weights, bias, dilation, groups and padding mode are preserved exactly.

    Args:
        model: Module whose convolution layers should be wrapped.

    Returns:
        The same ``model`` object, after modification.
    """
    # Snapshot the children first: attributes are reassigned while walking.
    for name, module in list(model.named_children()):
        if isinstance(module, CachedConv2D):
            # Already wrapped; descending would wrap the inner conv a second time.
            continue
        # `torch.nn` is spelled out because `nn` is rebound to `mlx.nn` in this notebook.
        if isinstance(module, torch.nn.Conv2d):
            new_layer = CachedConv2D(
                module.in_channels, module.out_channels, module.kernel_size, module.stride, module.padding
            )
            # Reuse the original layer instead of copying only `weight`.
            new_layer.conv = module
            new_layer.train(module.training)
            setattr(model, name, new_layer)
        else:
            replace_convolution_layers(module)  # Recursive replacement

    return model


In [8]:
def profile_model(model, input_tensor, log_dir="logs"):
    """Profile one forward pass on the CPU and log the results for TensorBoard.

    The model is switched to eval mode (and left in eval mode). A single
    forward pass is executed without gradient tracking under the PyTorch
    profiler, whose trace is written to ``log_dir``. Afterwards the model
    graph is added to the same directory through a ``SummaryWriter``.

    Args:
        model: The ``torch.nn.Module`` to profile.
        input_tensor: Example input passed to ``model``.
        log_dir: Directory receiving the profiler trace and the model graph.
    """
    model.eval()

    # Using the writer as a context manager closes it even if profiling
    # or graph export raises.
    with SummaryWriter(log_dir) as writer:
        with torch.no_grad():
            with profile(
                activities=[ProfilerActivity.CPU],  # CPU-side operators only
                record_shapes=True,  # Captures tensor shapes
                profile_memory=True,  # Tracks memory usage
                with_stack=True,  # Captures function call stack
                on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir),  # Saves trace for TensorBoard
            ):
                model(input_tensor)

        # Export the model graph (the profiler trace is written by the handler above).
        writer.add_graph(model, input_tensor)

## Driver

In [9]:
# Load Pre-Trained Model (pass a weights enum member, not the enum class itself)
original_model = resnet18(weights=ResNet18_Weights.DEFAULT)

# Set Device
device = torch.device("cpu")

# Create Modified Model with Cached Convolution.
# replace_convolution_layers() mutates its argument, so it gets its own model
# instance; otherwise both names would refer to the same modified network and
# the comparison below would be meaningless.
cached_model = replace_convolution_layers(resnet18(weights=ResNet18_Weights.DEFAULT))

# Inference mode: fixed BatchNorm statistics for a deterministic comparison
original_model.to(device).eval()
cached_model.to(device).eval()

# Generate Test Input (seeded so the notebook is reproducible)
torch.manual_seed(0)
test_input = torch.randn(1, 3, 224, 224, device=device)

# Ensure the Modified Model Produces Identical Output
with torch.no_grad():
    original_output = original_model(test_input)
    cached_output = cached_model(test_input)

assert torch.allclose(original_output, cached_output, atol=1e-5), "Mismatch in outputs!"

# Profile the Model
profile_model(cached_model, test_input, log_dir="logs/cached_conv")
profile_model(original_model, test_input, log_dir="logs/original_conv")