<a href="https://colab.research.google.com/github/Ahbar1999/mtp-pimsimulator/blob/main/pimsimulator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

In [None]:
class PIMSimulator:
    """Wear-tracking model of a processing-in-memory (PIM) crossbar array.

    Tensors are laid out over consecutive fixed-capacity crossbars. Cell writes
    are accumulated per *logical* crossbar in the interval write counters
    (``iwc_counters``). Every ``remapping_interval`` logged write accesses the
    interval counts are charged to the *physical* crossbars they were mapped to
    (total write counters, ``twc_counters``). The logical->physical
    ``mapping_table`` is then rebuilt so the hottest logical crossbars of the
    interval move onto the least-worn physical crossbars.
    """

    def __init__(self, num_crossbars=1024, crossbar_size=(128, 128), bits_per_cell=1, weight_bits=8,
                 remapping_interval=1000):
        """Create a simulator with an identity logical->physical mapping.

        Args:
            num_crossbars: Number of crossbars (logical and physical).
            crossbar_size: ``(rows, cols)`` cells per crossbar.
            bits_per_cell: Bits stored in one cell.
            weight_bits: Bits per stored weight.
            remapping_interval: Logged write accesses between remappings.

        Raises:
            ValueError: If a single crossbar cannot hold one weight, or if
                ``remapping_interval`` is smaller than 1.
        """
        self.num_crossbars = num_crossbars
        self.crossbar_rows, self.crossbar_cols = crossbar_size
        self.bits_per_cell = bits_per_cell
        self.weight_bits = weight_bits

        # A weight spans ceil(weight_bits / bits_per_cell) cells; a partially
        # filled cell still occupies (and is written as) a whole cell.
        self.cells_per_weight = math.ceil(weight_bits / bits_per_cell)
        self.cells_per_crossbar = self.crossbar_rows * self.crossbar_cols
        self.weights_per_crossbar = self.cells_per_crossbar // self.cells_per_weight
        if self.weights_per_crossbar == 0:
            raise ValueError(
                f"crossbar of {self.cells_per_crossbar} cells cannot hold one "
                f"{weight_bits}-bit weight at {bits_per_cell} bit(s) per cell"
            )
        if remapping_interval < 1:
            raise ValueError(f"remapping_interval must be >= 1, got {remapping_interval}")

        # Interval write count, indexed by LOGICAL crossbar id (write hotness).
        self.iwc_counters = np.zeros(num_crossbars, dtype=np.uint32)
        # Total write count, indexed by PHYSICAL crossbar id (accumulated wear).
        self.twc_counters = np.zeros(num_crossbars, dtype=np.uint64)

        # Logical -> physical crossbar id; addressing granularity is a crossbar.
        self.mapping_table = {i: i for i in range(num_crossbars)}
        self.access_count = 0  # number of logged *write* accesses
        self.remapping_interval = remapping_interval

        print(f"Crossbar config: {crossbar_size}, {self.weights_per_crossbar} weights per crossbar")

    def map_tensor_to_crossbars(self, tensor):
        """Split ``tensor``'s elements across consecutive logical crossbars.

        The starting logical crossbar is derived from ``tensor.data_ptr()``, so
        a given data pointer always starts on the same logical crossbar. Each
        logical id is translated through ``mapping_table`` to a physical id.

        Returns:
            A list with one dict per crossbar used, holding:
            ``'crossbar_id'`` (physical id), ``'logical_id'``, ``'elements'``
            (weights placed there) and ``'utilization'`` (fraction of the
            crossbar's weight capacity used). The list is empty for an empty
            tensor.
        """
        remaining = tensor.numel()
        crossbars_needed = math.ceil(remaining / self.weights_per_crossbar)
        base_crossbar = hash(tensor.data_ptr()) % self.num_crossbars

        crossbar_mappings = []
        for offset in range(crossbars_needed):
            logical_id = (base_crossbar + offset) % self.num_crossbars
            elements = min(remaining, self.weights_per_crossbar)
            crossbar_mappings.append({
                'crossbar_id': self.mapping_table.get(logical_id, logical_id),
                'logical_id': logical_id,
                'elements': elements,
                'utilization': elements / self.weights_per_crossbar,
            })
            remaining -= elements

        return crossbar_mappings

    def log_memory_access(self, tensor, is_write=True, operation="unknown"):
        """Charge the cell writes needed to store ``tensor`` to its crossbars.

        Reads are ignored and are not counted in ``access_count``.
        ``operation`` is a free-form label that is currently unused. Every
        ``remapping_interval``-th logged write triggers ``perform_remapping``.
        """
        if not is_write:
            return

        for mapping in self.map_tensor_to_crossbars(tensor):
            # Hotness belongs to the LOGICAL crossbar; perform_remapping charges
            # it to the physical crossbar currently backing it.
            self.iwc_counters[mapping['logical_id']] += mapping['elements'] * self.cells_per_weight

        self.access_count += 1
        if self.access_count % self.remapping_interval == 0:
            self.perform_remapping()

    def perform_remapping(self):
        """Fold interval wear into total wear and re-assign logical crossbars.

        1. Each logical crossbar's interval writes are added to the total wear
           of the physical crossbar it was mapped to during the interval.
        2. Logical crossbars ranked by interval writes (hottest first, ties in
           id order) are paired with physical crossbars ranked by total wear
           (least worn first, ties in id order).
        3. Interval counters are reset for the next interval.
        """
        # Snapshot hotness BEFORE resetting the counters; ranking a zeroed
        # array would make the remap ignore write traffic entirely.
        interval_writes = self.iwc_counters.astype(np.int64)

        physical_ids = np.array(
            [self.mapping_table.get(laid, laid) for laid in range(self.num_crossbars)],
            dtype=np.intp,
        )
        np.add.at(self.twc_counters, physical_ids, self.iwc_counters.astype(np.uint64))

        hottest_logical_first = np.argsort(-interval_writes, kind='stable')
        least_worn_physical_first = np.argsort(self.twc_counters, kind='stable')
        self.mapping_table = {
            int(laid): int(paid)
            for laid, paid in zip(hottest_logical_first, least_worn_physical_first)
        }

        self.iwc_counters[:] = 0

    def get_crossbar_utilization_stats(self):
        """Summarize crossbar activity and wear.

        ``active_crossbars`` counts logical crossbars written during the current
        (not yet folded) interval. ``wear_imbalance`` is max/(min + 1) of total
        wear; the +1 avoids division by zero.
        """
        stats = {
            'total_crossbars': self.num_crossbars,
            'weights_per_crossbar': self.weights_per_crossbar,
            'active_crossbars': np.count_nonzero(self.iwc_counters),
            'avg_wear': np.mean(self.twc_counters),
            'max_wear': np.max(self.twc_counters),
            'wear_imbalance': np.max(self.twc_counters) / (np.min(self.twc_counters) + 1)
        }
        return stats

In [None]:
class PIMTrainingHook:
    """Attach PyTorch hooks that report a model's memory traffic to a PIM simulator.

    For every submodule exposing a non-None ``weight``:
      * a forward hook logs the weight as a read and the module's tensor output
        as a write;
      * if the weight requires grad, a tensor hook on it fires when its gradient
        is computed during ``backward()``. That hook logs the gradient as a read
        and the weight as a write, so the write is accounted once per backward
        pass, before the optimizer applies the update.
    """

    def __init__(self, pim_simulator, verbose=True):
        """
        Args:
            pim_simulator: Object providing ``log_memory_access``,
                ``map_tensor_to_crossbars``, ``weight_bits`` and
                ``bits_per_cell`` (e.g. ``PIMSimulator``).
            verbose: If True (the default), print a crossbar-mapping summary
                every time a weight gradient is produced.
        """
        self.pim_sim = pim_simulator
        self.verbose = verbose
        self.hooks = []  # removable handles returned by PyTorch
        self.model_params = {}  # "<module name>.weight" -> weight parameter

    def register_hooks(self, model):
        """Register forward and weight-gradient hooks on every weighted submodule of ``model``."""
        for name, module in model.named_modules():
            weight = getattr(module, 'weight', None)
            if weight is None:
                continue
            self.model_params[name + '.weight'] = weight

            # ``name=name`` binds the current module name into each closure
            # (avoids Python's late-binding of loop variables).
            self.hooks.append(module.register_forward_hook(
                lambda mod, inputs, output, name=name: self.forward_hook(mod, inputs, output, name)
            ))

            if weight.requires_grad:
                self.hooks.append(weight.register_hook(
                    lambda grad, name=name: self.parameter_hook(grad, name)
                ))

    def forward_hook(self, module, input, output, name):
        """Log the forward pass: a weight read and, for tensor outputs, an output write."""
        if hasattr(module, 'weight'):
            self.pim_sim.log_memory_access(
                module.weight, is_write=False, operation=f"forward_{name}"
            )
        if isinstance(output, torch.Tensor):
            self.pim_sim.log_memory_access(
                output, is_write=True, operation=f"output_{name}"
            )

    def parameter_hook(self, grad, name):
        """Log the gradient read and weight write for module ``name``.

        Returns None, so autograd leaves ``grad`` unchanged.
        """
        if grad is None:
            return
        param = self.get_parameter_by_name(name)
        if param is None:
            # Not registered via register_hooks; there is no tensor to charge.
            return

        self.pim_sim.log_memory_access(grad, is_write=False, operation=f"grad_read_{name}")
        self.pim_sim.log_memory_access(param, is_write=True, operation=f"param_write_{name}")

        if self.verbose:
            mappings = self.pim_sim.map_tensor_to_crossbars(param)
            print(f"Parameter {name}: {param.shape} -> {len(mappings)} crossbars, "
                  f"{self._total_cell_writes(mappings)} total writes")

    def _total_cell_writes(self, mappings):
        """Return the cell writes needed to store every element in ``mappings``.

        Multi-bit cells are accounted for: each weight occupies
        ceil(weight_bits / bits_per_cell) cells, and each cell is written once.
        """
        cells_per_weight = math.ceil(self.pim_sim.weight_bits / self.pim_sim.bits_per_cell)
        return sum(m['elements'] for m in mappings) * cells_per_weight

    def get_parameter_by_name(self, name):
        """Return the weight registered for module ``name``, or None if unknown."""
        # Hooks receive the module name; parameters are stored as "<name>.weight".
        return self.model_params.get(name + '.weight')

    def cleanup(self):
        """Remove all registered hooks; safe to call more than once."""
        for handle in self.hooks:
            handle.remove()
        self.hooks.clear()

In [None]:
# Load the MNIST handwritten-digit training split as normalized tensors,
# served in shuffled mini-batches of 64.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
)

100%|██████████| 9.91M/9.91M [00:00<00:00, 51.3MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.67MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.6MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.24MB/s]


In [None]:
# Initialize the PIM simulator and the hooks that feed it.
pim_simulator = PIMSimulator(num_crossbars=512, crossbar_size=(128, 128))
pim_hook = PIMTrainingHook(pim_simulator)

# Three-layer MLP: each image is flattened to a 784-vector, output is 10 class logits.
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10),
)

# Register PIM monitoring hooks on every layer with a weight.
pim_hook.register_hooks(model)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

NUM_EPOCHS = 10
LOG_EVERY_N_BATCHES = 100

try:
    for epoch in range(NUM_EPOCHS):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()

            output = model(data)  # forward hooks log weight reads / output writes
            loss = criterion(output, target)
            loss.backward()       # weight-gradient hooks log parameter writes
            optimizer.step()

            if batch_idx % LOG_EVERY_N_BATCHES == 0:
                print(f"Epoch: {epoch}, Batch: {batch_idx}")
                print(f"Max crossbar wear: {pim_simulator.twc_counters.max()}")
                print(f"Wear distribution std: {pim_simulator.twc_counters.std():.2f}")
finally:
    # Detach hooks even if training is interrupted (e.g. KeyboardInterrupt).
    pim_hook.cleanup()

# Writes logged since the last periodic remapping are still only in the
# interval counters; fold them into total wear so the report covers the whole run.
if pim_simulator.iwc_counters.any():
    pim_simulator.perform_remapping()

# Analysis
twc = pim_simulator.twc_counters
print("\nFinal TIWL Statistics:")
print(f"Total memory accesses (writes): {pim_simulator.access_count}")
print(f"Average crossbar wear: {twc.mean():.2f}")
print(f"Wear imbalance (max/(min+1)): {twc.max() / (twc.min() + 1):.2f}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Parameter 5: torch.Size([10, 128]) -> 1 crossbars, 10240 total writes
Parameter 3: torch.Size([128, 256]) -> 16 crossbars, 262144 total writes
Parameter 1: torch.Size([256, 784]) -> 98 crossbars, 1605632 total writes
Parameter 5: torch.Size([10, 128]) -> 1 crossbars, 10240 total writes
Parameter 3: torch.Size([128, 256]) -> 16 crossbars, 262144 total writes
Parameter 1: torch.Size([256, 784]) -> 98 crossbars, 1605632 total writes
Parameter 5: torch.Size([10, 128]) -> 1 crossbars, 10240 total writes
Parameter 3: torch.Size([128, 256]) -> 16 crossbars, 262144 total writes
Parameter 1: torch.Size([256, 784]) -> 98 crossbars, 1605632 total writes
Parameter 5: torch.Size([10, 128]) -> 1 crossbars, 10240 total writes
Parameter 3: torch.Size([128, 256]) -> 16 crossbars, 262144 total writes
Parameter 1: torch.Size([256, 784]) -> 98 crossbars, 1605632 total writes
Parameter 5: torch.Size([10, 128]) -> 1 crossbars, 10240 total writ