Evaluating Pytorch saved tensor API

In [1]:
import torch
import torch.nn as nn
from typing import Optional, Union, Any, Iterable

In [2]:
class MLP(nn.Module):
    """
    Basic MLP (multi-layer perceptron) layer with optional Dropout.
    """

    def __init__(
        self,
        d_model: int,
        act_fn: nn.Module,
        dropout_prob: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.d_model = d_model
        self.act_fn = act_fn
        self.dropout_prob = dropout_prob
        factory_kwargs = {"device": device, "dtype": dtype}

        self.lin_0 = nn.Linear(self.d_model, 4 * self.d_model, **factory_kwargs)
        self.lin_1 = nn.Linear(4 * self.d_model, self.d_model, **factory_kwargs)
        self.dropout = nn.Dropout(self.dropout_prob) if self.dropout_prob else None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.lin_0(inputs)
        x = self.act_fn(x)
        x = self.lin_1(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x

In [3]:
class AllocatedMemContext:
    def __init__(self) -> None:
        # Ensure CUDA libraries are loaded:
        torch.cuda.current_blas_handle()

        self.before: dict[str, int] = {}
        self.after: dict[str, int] = {}
        self.delta: dict[str, int] = {}

    def _get_mem_dict(self) -> dict[str, int]:
        # Only need `allocated_bytes.all`-prefixed keys here
        key_prefix = "allocated_bytes.all."
        return {
            k.replace(key_prefix, ""): v
            for k, v in torch.cuda.memory_stats().items()
            if key_prefix in k
        }

    def __enter__(self) -> "AllocatedMemContext":
        self.before = self._get_mem_dict()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.after = self._get_mem_dict()
        self.delta = {k: v - self.before[k] for k, v in self.after.items()}

In [4]:
class SavedTensorContext:
    def __init__(
        self,
        ignored_tensors: Optional[Iterable[torch.Tensor]] = None,
    ) -> None:
        self._ignored_data_ptrs = (
            set()
            if ignored_tensors is None
            else {t.untyped_storage().data_ptr() for t in ignored_tensors}
        )

        self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary()

        def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            data_ptr = saved_tensor.untyped_storage().data_ptr()
            if data_ptr not in self._ignored_data_ptrs:
                self.saved_tensor_dict[saved_tensor] = data_ptr
            return saved_tensor

        def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            return saved_tensor

        self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(
            pack_hook, unpack_hook
        )

    def __enter__(self) -> "SavedTensorContext":
        self._saved_tensors_hook.__enter__()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self._saved_tensors_hook.__exit__(*args, **kwargs)

    @property
    def saved_tensor_mem(self) -> int:
        """
        The memory in bytes of all saved tensors, accounting for views into the same storage.
        """
        accounted_for = self._ignored_data_ptrs.copy()
        total_bytes = 0
        for t in self.saved_tensor_dict:
            data_ptr = t.untyped_storage().data_ptr()
            if data_ptr not in accounted_for:
                total_bytes += t.untyped_storage().nbytes()
                accounted_for.add(data_ptr)
        return total_bytes

In [6]:
batch_size, seq_len, d_model = 2, 4096, 1024
dtype = torch.bfloat16
inputs = torch.randn(
    batch_size,
    seq_len,
    d_model,
    device="cuda",
    requires_grad=True,
    dtype=dtype,
)

act_fn_dict = {"ReLU": nn.ReLU(), "GELU": nn.GELU()}
# Append outputs to a list to keep tensors alive
outputs = []
mem_bytes = []

for name, act_fn in act_fn_dict.items():
    mlp = MLP(
        d_model=d_model,
        act_fn=act_fn,
        device="cuda",
        dtype=dtype,
    )
    with AllocatedMemContext() as mem, SavedTensorContext(
        ignored_tensors=mlp.parameters()
    ) as saved:
        out = mlp(inputs)
        outputs.append(out)
    # assert mem.delta["current"] == saved.saved_tensor_mem
    print(f"{name} bytes: {saved.saved_tensor_mem}")
    mem_bytes.append(saved.saved_tensor_mem)

print(f"ReLU/GeLU act mem ratio: {mem_bytes[0]/mem_bytes[1]}")

ReLU bytes: 83886080
GELU bytes: 150994944
ReLU/GeLU act mem ratio: 0.5555555555555556


In [8]:
# Pytorch weakkeytensor

import torch
from torch.utils.weak import WeakTensorKeyDictionary

# Create a WeakTensorKeyDictionary
weak_dict = WeakTensorKeyDictionary()

# Create a tensor and add it to the dictionary
tensor1 = torch.tensor([1, 2, 3])
weak_dict[tensor1] = "This is tensor 1"

print("Before deleting tensor1:")
print(weak_dict, len(weak_dict), weak_dict[tensor1], list(weak_dict.keys()))

# Delete the reference to the tensor
del tensor1

# Now tensor1 can be garbage collected
# Since we deleted the only strong reference to tensor1,
# it should also be removed from weak_dict
print("After deleting tensor1:")
print(weak_dict, len(weak_dict), list(weak_dict.keys()))


Before deleting tensor1:
<WeakIdKeyDictionary at 0x73476421c460> 1 This is tensor 1 [tensor([1, 2, 3])]
After deleting tensor1:
<WeakIdKeyDictionary at 0x73476421c460> 0 []


In [12]:
import torch 
import weakref

class WeakTensorList:
    def __init__(self):
        self._refs = []

    def append(self, tensor):
        # Add a weak reference to the tensor
        self._refs.append(weakref.ref(tensor))

    def __getitem__(self, index):
        # Retrieve the tensor, if it's still alive
        tensor_ref = self._refs[index]()
        if tensor_ref is None:
            print(f"Tensor at index {index} has been garbage collected.")
        return tensor_ref

    def __len__(self):
        return len(self._refs)

    def cleanup(self):
        # Clean up any None references from the list
        self._refs = [ref for ref in self._refs if ref() is not None]

# Usage example
tensor_list = WeakTensorList()
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([4, 5, 6])

tensor_list.append(tensor1)
tensor_list.append(tensor2)

print("Before deleting tensors:")
for i in range(len(tensor_list)):
    print(f"Tensor at index {i}: {tensor_list[i]}")
    print(tensor_list._refs[i]())

# Delete the strong references to the tensors
del tensor1
del tensor2

print("\nAfter deleting tensors:")
for i in range(len(tensor_list)):
    print(f"Tensor at index {i}: {tensor_list[i]}")

# Clean up any None references
tensor_list.cleanup()
print("\nAfter cleanup, length of list:", len(tensor_list))


Before deleting tensors:
Tensor at index 0: tensor([1, 2, 3])
tensor([1, 2, 3])
Tensor at index 1: tensor([4, 5, 6])
tensor([4, 5, 6])

After deleting tensors:
Tensor at index 0 has been garbage collected.
Tensor at index 0: None
Tensor at index 1 has been garbage collected.
Tensor at index 1: None

After cleanup, length of list: 0


In [13]:
a = [1,2,3]

a[:-1]

[1, 2]