### This notebook showcases a way to implement a memory-efficient linear layer

SOURCE: https://medium.com/@yash9439/unslothais-innovative-hiring-challenge-memory-efficient-backprop-a5dc2372d469

In [3]:
import torch
from torch.nn import CrossEntropyLoss
import time

def transformation_function(batch, linear, labels):
    """Project ``batch`` through ``linear`` and return the mean cross-entropy.

    The output of ``linear`` is upcast to float32 before the loss is taken.
    Logits are flattened to ``(-1, last_dim)`` and labels to ``(-1,)``, so
    any leading batch/sequence layout is accepted as long as both flatten to
    the same number of rows.

    Args:
        batch: Input activations passed to ``linear``.
        linear: Callable (e.g. ``torch.nn.Linear``) producing the logits.
        labels: Integer class targets, one per logits row after flattening.

    Returns:
        A scalar tensor holding the mean cross-entropy loss.
    """
    logits = linear(batch).float()
    num_classes = logits.shape[-1]
    criterion = CrossEntropyLoss(reduction="mean")
    return criterion(logits.view(-1, num_classes), labels.view(-1))

class MemoryEfficientLinear(torch.autograd.Function):
    """Chunked ``linear`` + loss evaluation with recomputation in backward.

    ``X`` is cut into at most four pieces along dim 0.  ``forward`` evaluates
    ``forward_function`` on each piece without recording a graph and combines
    the per-chunk losses into one weighted mean.  ``backward`` re-runs
    ``forward_function`` chunk by chunk with gradients enabled, so only one
    chunk's intermediate tensors are alive at a time.

    Usage: ``MemoryEfficientLinear.apply(X, linear, labels, forward_function)``
    where ``forward_function(x_chunk, linear, labels_chunk)`` returns a scalar
    loss that is a mean over the chunk.

    Notes:
        * Each chunk loss is weighted by ``labels_chunk.numel()``.  This
          reproduces the global mean only when ``forward_function`` averages
          over every label element (e.g. no ignored targets).
        * ``linear`` is a module, not a tensor, so autograd cannot route
          gradients to it through the return value of ``backward``.
          ``backward`` therefore accumulates the parameter gradients directly
          into ``linear.weight.grad`` / ``linear.bias.grad``.
        * ``backward`` only runs when ``X`` requires grad; otherwise the
          returned loss has no graph and no parameter gradients are produced.
    """

    @staticmethod
    def _split_labels(X, labels, row_counts):
        """Split ``labels`` into pieces aligned with chunks of ``X`` rows.

        Two-dimensional labels are flattened before splitting (matching what
        the original implementation handed to ``forward_function``); any other
        rank is split along dim 0 and keeps its trailing shape.

        Raises:
            ValueError: if the labels cannot be distributed evenly over the
                rows of ``X``.
        """
        batch_rows = X.size(0)
        if labels.dim() == 2:
            source = labels.reshape(-1)
            total = source.numel()
        else:
            source = labels
            total = labels.size(0)
        if batch_rows and total % batch_rows:
            raise ValueError(
                f"cannot align {total} label entries with {batch_rows} rows of X"
            )
        per_row = total // batch_rows if batch_rows else 0
        return torch.split(source, [rows * per_row for rows in row_counts], dim=0)

    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        """Compute the weighted mean of the per-chunk losses.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            X: Input tensor, chunked along dim 0.
            linear: Module exposing ``weight`` and ``bias`` attributes.
            labels: Targets; split so each piece matches one chunk of ``X``.
            forward_function: ``(x_chunk, linear, labels_chunk) -> scalar loss``.

        Returns:
            Scalar tensor: ``sum(chunk_loss * chunk_elements) / total_elements``.
        """
        num_chunks = 4
        chunks = X.chunk(num_chunks, dim=0)
        # ``chunk`` may return fewer (or uneven) pieces; record actual sizes.
        row_counts = [chunk.size(0) for chunk in chunks]
        labels_chunks = MemoryEfficientLinear._split_labels(X, labels, row_counts)

        sum_loss = 0.0
        elements_in_chunks = []
        # No graph is needed here: backward recomputes every chunk.
        with torch.no_grad():
            for x_chunk, labels_chunk in zip(chunks, labels_chunks):
                chunk_loss = forward_function(x_chunk.detach(), linear, labels_chunk)
                elements_in_chunk = labels_chunk.numel()
                sum_loss = sum_loss + chunk_loss * elements_in_chunk
                elements_in_chunks.append(elements_in_chunk)

        total_elements = sum(elements_in_chunks)
        final_loss = sum_loss / total_elements

        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.row_counts = row_counts
        ctx.elements_in_chunks = elements_in_chunks
        ctx.total_elements = total_elements
        return final_loss

    @staticmethod
    def backward(ctx, dY):
        """Recompute each chunk and collect gradients for ``X`` and ``linear``.

        Args:
            ctx: Context populated by ``forward``.
            dY: Gradient of the downstream objective w.r.t. the returned loss.

        Returns:
            ``(dX, None, None, None)``.  Gradients for ``linear.weight`` and
            ``linear.bias`` (when present and trainable) are accumulated into
            their ``.grad`` attributes as a side effect.
        """
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function
        row_counts = ctx.row_counts
        total_elements = ctx.total_elements

        X_chunks = X.split(row_counts, dim=0)
        labels_chunks = MemoryEfficientLinear._split_labels(X, labels, row_counts)

        # Only differentiate w.r.t. parameters that exist and are trainable;
        # passing ``None`` or a frozen tensor to autograd.grad would raise.
        params = [
            p for p in (linear.weight, linear.bias)
            if p is not None and p.requires_grad
        ]
        param_grads = [torch.zeros_like(p) for p in params]
        dX = torch.zeros_like(X)

        offset = 0
        for x_part, labels_chunk, elements_in_chunk in zip(
            X_chunks, labels_chunks, ctx.elements_in_chunks
        ):
            x_chunk = x_part.detach().requires_grad_(True)
            # Each chunk contributes in proportion to its share of elements.
            scale = dY * (elements_in_chunk / total_elements)
            with torch.enable_grad():
                chunk_loss = forward_function(x_chunk, linear, labels_chunk)
            grads = torch.autograd.grad(
                chunk_loss,
                [x_chunk, *params],
                grad_outputs=scale,
                retain_graph=False,
            )
            # Running offset: chunks are not guaranteed to be equally sized.
            rows = x_chunk.size(0)
            dX[offset:offset + rows] = grads[0]
            offset += rows
            for accumulated, grad in zip(param_grads, grads[1:]):
                accumulated += grad

        with torch.no_grad():
            for param, grad in zip(params, param_grads):
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad

        return dX, None, None, None

# ---------------- Setup for CrossEntropy Test ----------------
# Problem dimensions shared by every cell below.
batch_size = 64
seq_len = 512
hidden_dim = 768
vocab_size = 16000
# The benchmark cells call torch.cuda.* directly, so a CUDA device is required.
device = 'cuda'

# Input activations: (batch_size, seq_len, hidden_dim) in bfloat16. They are a
# leaf tensor with requires_grad=True so that both code paths populate
# input_data.grad, which is what the final comparison inspects.
input_data = torch.randn(batch_size, seq_len, hidden_dim,
                         dtype=torch.bfloat16, device=device, requires_grad=True)
# Integer class targets in [0, vocab_size), one per (batch, position).
labels = torch.randint(0, vocab_size, (batch_size, seq_len),
                       dtype=torch.long, device=device)

def create_linear():
    """Build a bfloat16 ``hidden_dim -> vocab_size`` linear layer on ``device``.

    The layer is initialised in the default dtype and then cast: the weight
    is replaced by a new bfloat16 ``Parameter`` and the bias storage is
    converted in place.  Reads the module-level ``hidden_dim``,
    ``vocab_size`` and ``device``.

    Returns:
        The configured ``torch.nn.Linear`` module.
    """
    layer = torch.nn.Linear(hidden_dim, vocab_size, device=device)

    bf16_weight = layer.weight.to(torch.bfloat16)
    layer.weight = torch.nn.Parameter(bf16_weight)

    bias = layer.bias
    if bias is not None:
        bias.data = bias.data.to(torch.bfloat16)

    return layer

### Testing

In [2]:
# Two independently constructed layers: one for the reference path and one
# for the chunked path.
linear_standard = create_linear()
linear_custom = create_linear()
# Copy the reference parameters into the custom layer so both paths start
# from identical weights and bias.
linear_custom.load_state_dict(linear_standard.state_dict())

def standard_forward_backward():
    """Run one reference forward/backward pass over the full batch.

    Clears the gradients of ``linear_standard`` and ``input_data``, computes
    the logits for the whole batch at once, takes the mean cross-entropy
    against ``labels`` and back-propagates it.

    NOTE(review): the logits are not upcast to float32 here, whereas
    ``transformation_function`` calls ``.float()`` before the loss, so the
    two paths evaluate the loss in different dtypes -- confirm whether that
    is intended.

    Returns:
        The detached scalar loss tensor.
    """
    linear_standard.zero_grad()
    stale_grad = input_data.grad
    if stale_grad is not None:
        stale_grad.zero_()

    logits = linear_standard(input_data)
    criterion = CrossEntropyLoss(reduction='mean')
    loss = criterion(logits.view(-1, vocab_size), labels.view(-1))
    loss.backward()
    return loss.detach()

def custom_forward_backward():
    """Run one forward/backward pass through ``MemoryEfficientLinear``.

    Clears the gradients of ``linear_custom`` and ``input_data``, evaluates
    the chunked loss with ``transformation_function`` as the per-chunk
    forward, and back-propagates it.

    Returns:
        The detached scalar loss tensor.
    """
    linear_custom.zero_grad()
    stale_grad = input_data.grad
    if stale_grad is not None:
        stale_grad.zero_()

    loss = MemoryEfficientLinear.apply(
        input_data,
        linear_custom,
        labels,
        transformation_function,
    )
    loss.backward()
    return loss.detach()

# Warmup to avoid CUDA initialization overhead
# (two untimed iterations of each path before anything is measured).
for _ in range(2):
    standard_forward_backward()
    custom_forward_backward()

# ---------------- CrossEntropy Timing and VRAM Measurements ----------------
# Timing: one forward+backward call per path. The synchronize() before the
# start and after the call brackets the timed region so the GPU work of the
# call itself is what gets measured.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start_time = time.time()
loss_standard = standard_forward_backward()
torch.cuda.synchronize()
time_standard = time.time() - start_time

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
start_time = time.time()
loss_custom = custom_forward_backward()
torch.cuda.synchronize()
time_custom = time.time() - start_time

# NOTE(review): standard_forward_backward takes the loss on the raw layer
# output while transformation_function upcasts to float32 first; that dtype
# difference may account for part of the reported loss gap -- confirm.
print(f"Standard Loss (CE): {loss_standard.item():.6f}")
print(f"Custom Loss (CE): {loss_custom.item():.6f}")
loss_diff = torch.abs(loss_standard - loss_custom).item()
print(f"CE Loss Difference: {loss_diff:.6f}")

# Peak memory: reset the peak counter, run one pass, then read the peak.
# Tensors that are already allocated (inputs, labels, both layers and any
# existing .grad buffers) are included in the reported figure.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
standard_forward_backward()
mem_standard = torch.cuda.max_memory_allocated()

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
custom_forward_backward()
mem_custom = torch.cuda.max_memory_allocated()

print(f"\nStandard Time (CE): {time_standard:.6f} seconds")
print(f"Custom Time (CE): {time_custom:.6f} seconds")
print(f"Standard VRAM (CE): {mem_standard / (1024**3):.6f} GiB")
print(f"Custom VRAM (CE): {mem_custom / (1024**3):.6f} GiB")

# ---------------- CrossEntropy Input Gradients Comparison ----------------
# Only the gradient w.r.t. input_data is compared; the layers' weight/bias
# gradients are not inspected here.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Redundant with the zeroing inside the helper, kept for explicitness.
if input_data.grad is not None:
    input_data.grad.zero_()
loss_standard = standard_forward_backward()
# Clone because the next pass zeroes and overwrites input_data.grad in place.
grad_standard = input_data.grad.clone()

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
if input_data.grad is not None:
    input_data.grad.zero_()
loss_custom = custom_forward_backward()
grad_custom = input_data.grad.clone()

print("\nInput gradient comparison (CE):")
max_grad_diff = torch.max(torch.abs(grad_standard - grad_custom)).item()
print("Max difference between standard and custom input gradients:", max_grad_diff)

Standard Loss (CE): 9.875000
Custom Loss (CE): 9.850885
CE Loss Difference: 0.024115

Standard Time (CE): 1.023671 seconds
Custom Time (CE): 1.390536 seconds
Standard VRAM (CE): 4.063050 GiB
Custom VRAM (CE): 1.760712 GiB

Input gradient comparison (CE):
Max difference between standard and custom input gradients: 7.450580596923828e-09
