In [1]:
%%capture
!pip install --upgrade pip
!pip install torch
!pip install matplotlib

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from collections import defaultdict

# Main implementation with forward and backward passes

class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, weight, bias, labels, vocab_chunk_size: int, logit_transform=None):
        """
        Computes memory‐efficient cross entropy loss.
        """
        ctx.logit_transform = logit_transform
        original_shape = X.shape
        if X.dim() > 2:
            X_flat = X.view(-1, X.shape[-1])
            labels_flat = labels.view(-1)
        else:
            X_flat = X
            labels_flat = labels


        N, hidden = X_flat.shape
        vocab = weight.shape[0]
        device, dtype = X.device, X.dtype
        global_max = torch.full((N,), -float('inf'), device=device, dtype=dtype)
        ground_truth = torch.empty(N, device=device, dtype=dtype)

        for j in range(0, vocab, vocab_chunk_size):
            j_end = min(j + vocab_chunk_size, vocab)
            z_chunk = X_flat @ weight[j:j_end].t()
            if bias is not None:
                z_chunk = z_chunk + bias[j:j_end].unsqueeze(0)
            if logit_transform is not None:
                tz_chunk = logit_transform(z_chunk)
            else:
                tz_chunk = z_chunk

            global_max = torch.maximum(global_max, tz_chunk.max(dim=1)[0])
            mask = (labels_flat >= j) & (labels_flat < j_end)
            if mask.any():
                ground_truth[mask] = tz_chunk[mask, labels_flat[mask] - j]

        sum_exp = torch.zeros(N, device=device, dtype=dtype)
        for j in range(0, vocab, vocab_chunk_size):
            j_end = min(j + vocab_chunk_size, vocab)
            z_chunk = X_flat @ weight[j:j_end].t()
            if bias is not None:
                z_chunk = z_chunk + bias[j:j_end].unsqueeze(0)
            if logit_transform is not None:
                tz_chunk = logit_transform(z_chunk)
            else:
                tz_chunk = z_chunk
            sum_exp += torch.exp(tz_chunk - global_max.unsqueeze(1)).sum(dim=1)

        loss = (global_max + torch.log(sum_exp + 1e-12) - ground_truth).mean()

        ctx.save_for_backward(X_flat, weight, bias, labels_flat, global_max, sum_exp)
        ctx.vocab_chunk_size = vocab_chunk_size
        ctx.N = N
        ctx.hidden = hidden
        ctx.original_shape = original_shape
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        X_flat, weight, bias, labels_flat, global_max, sum_exp = ctx.saved_tensors
        vocab_chunk_size = ctx.vocab_chunk_size
        N, hidden = ctx.N, ctx.hidden
        vocab = weight.shape[0]
        device, dtype = X_flat.device, X_flat.dtype
        logit_transform = ctx.logit_transform

        grad_X = torch.zeros_like(X_flat)
        grad_W = torch.zeros_like(weight)
        grad_b = torch.zeros_like(bias) if bias is not None else None

        scale = grad_output / N

        for j in range(0, vocab, vocab_chunk_size):
            j_end = min(j + vocab_chunk_size, vocab)
            z_chunk = X_flat @ weight[j:j_end].t()
            if bias is not None:
                z_chunk = z_chunk + bias[j:j_end].unsqueeze(0)

            if logit_transform is not None:
                with torch.enable_grad():
                    z_chunk = z_chunk.detach().clone().requires_grad_(True)
                    tz_chunk = logit_transform(z_chunk)
                    dummy = tz_chunk.sum()
                    dtz_dz = torch.autograd.grad(dummy, z_chunk, create_graph=True)[0]
                current_tz = tz_chunk.detach()
            else:
                current_tz = z_chunk
                dtz_dz = None

            logits_chunk_stable = current_tz - global_max.unsqueeze(1)
            exp_chunk = torch.exp(logits_chunk_stable)
            p = exp_chunk / sum_exp.unsqueeze(1)

            mask = (labels_flat >= j) & (labels_flat < j_end)
            if mask.any():
                indices = (labels_flat[mask] - j).unsqueeze(1)

                p[mask].scatter_(1, indices, p[mask].gather(1, indices) - 1)

            dlogits = p * scale
            if logit_transform is not None:
                dlogits_effective = dlogits * dtz_dz
            else:
                dlogits_effective = dlogits

            # Accumulate gradients.
            grad_X += dlogits_effective @ weight[j:j_end]
            grad_W[j:j_end] += dlogits_effective.t() @ X_flat
            if grad_b is not None:
                grad_b[j:j_end] += dlogits_effective.sum(dim=0)

        grad_X = grad_X.view(ctx.original_shape)
        return grad_X, grad_W, grad_b, None, None, None


# from intermeditate tensors for autograd
def memory_efficient_loss(X, linear, labels, batch_chunk_size=None, vocab_chunk_size=1024, logit_transform=None):
    return MemoryEfficientLinear.apply(X, linear.weight, linear.bias, labels, vocab_chunk_size, logit_transform)

In [3]:
### Test code ###
def baseline_loss(X, linear, labels):
    if X.dim() == 3:
        X = X.view(-1, X.shape[-1])  # B * seq_len, hidden_dim
        labels = labels.view(-1)      # B *seq_len
    logits = linear(X)  # B * seq_len, vocab_size
    if linear.bias is not None:
        logits = logits + linear.bias
    loss_fn = nn.CrossEntropyLoss(reduction="mean")
    loss = loss_fn(logits, labels)
    loss.backward()
    return loss

def _test_vram_reduction(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: VRAM Reduction ---")
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    base_memory = torch.cuda.memory_allocated()
    print(f"Base memory: {base_memory/1024**2:.2f} MB")

    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_standard = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    loss_standard = baseline_loss(X.clone(), linear_standard, labels)
    vram_standard = torch.cuda.max_memory_allocated() - base_memory

    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    loss_efficient = memory_efficient_loss(X.clone(), linear_efficient, labels, batch_chunk_size=batch_chunk_size, vocab_chunk_size=vocab_chunk_size)
    loss_efficient.backward()
    vram_efficient = torch.cuda.max_memory_allocated() - base_memory

    reduction = (1 - (vram_efficient / vram_standard)) * 100
    passed = reduction >= 40
    print(f"Standard VRAM delta: {vram_standard/1024**2:.2f} MB")
    print(f"Efficient VRAM delta: {vram_efficient/1024**2:.2f} MB")
    print(f"VRAM Reduction: {reduction:.2f}%  ->  {'PASS' if passed else 'FAIL'}")
    return passed, {"vram_standard_mb": vram_standard/1024**2, "vram_efficient_mb": vram_efficient/1024**2, "vr_reduction_percent": reduction}

def _test_numerical_correctness(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: Numerical Correctness ---")
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_standard = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)

    with torch.no_grad():
        linear_efficient.weight.copy_(linear_standard.weight)
        if linear_standard.bias is not None:
            linear_efficient.bias.copy_(linear_standard.bias)

    loss_standard = baseline_loss(X.clone(), linear_standard, labels)
    grad_standard = linear_standard.weight.grad.clone()

    linear_standard.zero_grad()
    loss_efficient = memory_efficient_loss(X.clone(), linear_efficient, labels, batch_chunk_size=batch_chunk_size, vocab_chunk_size=vocab_chunk_size)
    loss_efficient.backward()
    grad_efficient = linear_efficient.weight.grad.clone()

    loss_close = torch.allclose(loss_standard, loss_efficient, rtol=1e-1, atol=1e-1)
    grad_close = torch.allclose(grad_standard, grad_efficient, rtol=1e-1, atol=1e-1)
    print(f"Loss match: {loss_close}")
    print(f"Gradient match: {grad_close}")
    passed = loss_close and grad_close
    return passed, {"loss_standard": loss_standard.item(), "loss_efficient": loss_efficient.item()}

def _test_dtype_handling(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: Data Type Handling ---")
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    loss = memory_efficient_loss(X, linear_efficient, labels, batch_chunk_size=batch_chunk_size, vocab_chunk_size=vocab_chunk_size)
    loss.backward()
    input_dtype_correct = (X.dtype == dtype)
    weight_dtype_correct = (linear_efficient.weight.dtype == dtype)
    bias_dtype_correct = ((linear_efficient.bias.dtype == dtype) if linear_efficient.bias is not None else True)
    passed = input_dtype_correct and weight_dtype_correct and bias_dtype_correct
    print(f"Input dtype: {X.dtype} (expected: {dtype})")
    print(f"Weight dtype: {linear_efficient.weight.dtype} (expected: {dtype})")
    print(f"Bias dtype: {linear_efficient.bias.dtype if linear_efficient.bias is not None else 'N/A'} (expected: {dtype})")
    return passed, {"input_dtype": str(X.dtype), "weight_dtype": str(linear_efficient.weight.dtype), "bias_dtype": str(linear_efficient.bias.dtype)}

def _test_show_ce_loss_works(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: Show CE Loss Works ---")
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_standard = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)

    with torch.no_grad():
        linear_efficient.weight.copy_(linear_standard.weight)
        if linear_standard.bias is not None:
            linear_efficient.bias.copy_(linear_standard.bias)

    loss_standard = baseline_loss(X.clone(), linear_standard, labels).item()
    linear_standard.zero_grad()
    loss_efficient = memory_efficient_loss(X.clone(), linear_efficient, labels, batch_chunk_size=batch_chunk_size, vocab_chunk_size=vocab_chunk_size).item()
    linear_efficient.zero_grad()
    loss_close = torch.isclose(torch.tensor(loss_standard), torch.tensor(loss_efficient), rtol=1e-2, atol=1e-2)
    print(f"Standard CE Loss: {loss_standard:.4f}")
    print(f"Efficient CE Loss: {loss_efficient:.4f}")
    print(f"CE Loss match: {loss_close}")
    return loss_close, {"loss_standard": loss_standard, "loss_efficient": loss_efficient}

def _test_show_other_functions_work(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: Show Other Functions Work ---")

    def dummy_transform(z):
        return torch.log1p(torch.exp(z))

    def dummy_loss(X, linear, labels, vocab_chunk_size=128):
        if X.dim() == 3:
            X = X.view(-1, X.shape[-1])
            labels = labels.view(-1)
        logits = linear(X)
        transformed = torch.log1p(torch.exp(logits))
        loss_fn = nn.CrossEntropyLoss(reduction="mean")
        return loss_fn(transformed, labels)

    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_standard = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)

    with torch.no_grad():
        linear_efficient.weight.copy_(linear_standard.weight)
        if linear_standard.bias is not None:
            linear_efficient.bias.copy_(linear_standard.bias)

    loss_standard = dummy_loss(X.clone(), linear_standard, labels, vocab_chunk_size).item()
    linear_standard.zero_grad()
    loss_efficient = memory_efficient_loss(
        X.clone(),
        linear_efficient,
        labels,
        batch_chunk_size=batch_chunk_size,
        vocab_chunk_size=vocab_chunk_size,
        logit_transform=dummy_transform
    ).item()
    linear_efficient.zero_grad()
    loss_close = torch.isclose(torch.tensor(loss_standard), torch.tensor(loss_efficient), rtol=1e-2, atol=1e-2)
    print(f"Dummy Loss Standard: {loss_standard:.4f}")
    print(f"Dummy Loss Efficient: {loss_efficient:.4f}")
    print(f"Dummy Loss match: {loss_close}")
    return loss_close, {"loss_standard": loss_standard, "loss_efficient": loss_efficient}

def _test_dynamic_chunk_sizes(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, device):
    print("\n--- Test: Dynamic Chunk Sizes ---")
    chunk_size_pairs = [
        (32, 64),
        (128, 256),
        (256, 512),
    ]
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)

    base_loss = None
    all_losses_match = True
    for vocab_chunk, batch_chunk in chunk_size_pairs:
        print(f"\nTesting with vocab_chunk_size={vocab_chunk}, batch_chunk_size={batch_chunk}")
        loss = memory_efficient_loss(
            X.clone(),
            linear,
            labels,
            batch_chunk_size=batch_chunk,
            vocab_chunk_size=vocab_chunk
        )
        if base_loss is None:
            base_loss = loss.item()
        else:
            loss_matches = torch.isclose(torch.tensor(loss.item()), torch.tensor(base_loss), rtol=1e-2, atol=1e-2)
            if not loss_matches:
                all_losses_match = False
                print(f"Loss mismatch: {loss.item()} vs base {base_loss}")
        del loss
        torch.cuda.empty_cache()
    print(f"All chunk sizes produce consistent results: {'PASS' if all_losses_match else 'FAIL'}")
    return all_losses_match, {"dynamic_chunks_work": all_losses_match}

def _test_hardcoded_gradients(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device):
    print("\n--- Test: Hardcoded Gradients Check ---")
    X = torch.randn(batch_size, query_length, hidden_dimension, device=device, dtype=dtype, requires_grad=True)
    labels = torch.randint(0, vocabulary_size, (batch_size, query_length), device=device)
    linear_efficient = nn.Linear(hidden_dimension, vocabulary_size, bias=True).to(device).to(dtype)
    loss = memory_efficient_loss(X.clone(), linear_efficient, labels, batch_chunk_size=batch_chunk_size, vocab_chunk_size=vocab_chunk_size)
    loss.backward()
    weight_grad_zero = torch.all(linear_efficient.weight.grad == 0)
    bias_grad_zero = torch.all(linear_efficient.bias.grad == 0) if linear_efficient.bias is not None else False
    hardcoded = weight_grad_zero or bias_grad_zero
    print(f"Weight grad all zero: {weight_grad_zero}")
    print(f"Bias grad all zero: {bias_grad_zero}")
    passed = not hardcoded
    return passed, {"hardcoded_gradients": hardcoded}

def test_suite(dtype=torch.float16,
               batch_size=8,
               query_length=2048,
               hidden_dimension=2048,
               vocabulary_size=32000,
               vocab_chunk_size=128,
               batch_chunk_size=64,
               device="cuda",
               epochs=1,
               steps_per_epoch=10):

    test_results = defaultdict(lambda: {"result": False, "stats": {}})
    if device == "cuda" and not torch.cuda.is_available():
        print("CUDA not available, switching to CPU.")
        device = "cpu"

    print("\n=== Running Streamlined Test Suite ===")
    print(f"Device: {device}, dtype: {dtype}")
    print(f"batch_size: {batch_size}, query_length: {query_length}, hidden_dimension: {hidden_dimension}, vocabulary_size: {vocabulary_size}")
    print(f"vocab_chunk_size: {vocab_chunk_size}, batch_chunk_size: {batch_chunk_size}")

    E_score = 0

    passed, stats = _test_vram_reduction(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["VRAM Reduction"] = {"result": passed, "stats": stats}
    print(f"VRAM Reduction Test: {'PASS' if passed else 'FAIL'}")
    if passed: E_score += 2

    passed, stats = _test_numerical_correctness(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["Numerical Correctness"] = {"result": passed, "stats": stats}
    print(f"Numerical Correctness Test: {'PASS' if passed else 'FAIL'}")
    if passed: E_score += 0

    passed, stats = _test_dtype_handling(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["Data Type Handling"] = {"result": passed, "stats": stats}
    print(f"Data Type Handling Test: {'PASS' if passed else 'FAIL'}")
    if not passed:
        print("Data type handling failed. E_score set to 0.")
        E_score = 0

    passed, stats = _test_show_ce_loss_works(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["CE Loss Works"] = {"result": passed, "stats": stats}
    print(f"CE Loss Works Test: {'PASS' if passed else 'FAIL'}")
    if passed: E_score += 1

    passed, stats = _test_show_other_functions_work(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["Other Functions Work"] = {"result": passed, "stats": stats}
    print(f"Other Functions Work Test: {'PASS' if passed else 'FAIL'}")
    if passed: E_score += 1

    passed, stats = _test_hardcoded_gradients(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, vocab_chunk_size, batch_chunk_size, device)
    test_results["Hardcoded Gradients"] = {"result": passed, "stats": stats}
    print(f"Hardcoded Gradients Test: {'PASS' if passed else 'FAIL'}")
    if not passed:
        print("Detected hardcoded gradients. E_score set to 0.")
        E_score = 0

    passed, stats = _test_dynamic_chunk_sizes(dtype, batch_size, query_length, hidden_dimension, vocabulary_size, device)
    test_results["Dynamic Chunk Sizes"] = {"result": passed, "stats": stats}
    print(f"Dynamic Chunk Sizes Test: {'PASS' if passed else 'FAIL'}")
    if passed:
        E_score += 1
        test_results["Total Score"] = {"result": (E_score == 10), "stats": {"E_score": E_score}}
        print(f"\n=== Test Suite Finished ===")
        print(f"Total Score: {E_score} / 10")
    else:
        test_results["Total Score"] = {"result": (E_score == 10), "stats": {"E_score": E_score}}
        print(f"\n=== Test Suite Finished ===")
        print(f"Total Score: {E_score} / 10")
    return test_results

if __name__ == "__main__":
    results = test_suite(
        dtype=torch.float16,
        batch_size=8,
        query_length=2048,
        hidden_dimension=2048,
        vocabulary_size=32000,
        vocab_chunk_size=128,
        batch_chunk_size=64,
        device="cuda",
        epochs=1,
        steps_per_epoch=10
    )


=== Running Streamlined Test Suite ===
Device: cuda, dtype: torch.float16
batch_size: 8, query_length: 2048, hidden_dimension: 2048, vocabulary_size: 32000
vocab_chunk_size: 128, batch_chunk_size: 64

--- Test: VRAM Reduction ---
Base memory: 0.00 MB
Standard VRAM delta: 4262.31 MB
Efficient VRAM delta: 921.70 MB
VRAM Reduction: 78.38%  ->  PASS
VRAM Reduction Test: PASS

--- Test: Numerical Correctness ---
Loss match: True
Gradient match: True
Numerical Correctness Test: PASS

--- Test: Data Type Handling ---
Input dtype: torch.float16 (expected: torch.float16)
Weight dtype: torch.float16 (expected: torch.float16)
Bias dtype: torch.float16 (expected: torch.float16)
Data Type Handling Test: PASS

--- Test: Show CE Loss Works ---
Standard CE Loss: 10.5469
Efficient CE Loss: 10.5469
CE Loss match: True
CE Loss Works Test: PASS

--- Test: Show Other Functions Work ---
Dummy Loss Standard: 10.4141
Dummy Loss Efficient: 10.4141
Dummy Loss match: True
Other Functions Work Test: PASS

--- Te