### Testing with MSE & ReLU

In [None]:
import torch

class MemoryEfficientLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, linear, labels, forward_function):
        num_chunks = 4
        chunks = X.chunk(num_chunks, dim=0)
        if labels.dim() == 2:
            split_sizes = [chunk.size(0) * X.size(1) for chunk in chunks]
            labels_flat = labels.view(-1)
            labels_chunks = torch.split(labels_flat, split_sizes, dim=0)
        else:
            split_sizes = None
            labels_chunks = labels.chunk(num_chunks, dim=0)

        sum_loss = 0.0
        total_elements = 0
        elements_in_chunks = []
        for x_chunk, labels_chunk in zip(chunks, labels_chunks):
            x_chunk = x_chunk.detach().requires_grad_(True)
            chunk_loss = forward_function(x_chunk, linear, labels_chunk)
            elements_in_chunk = labels_chunk.numel()
            sum_loss += chunk_loss * elements_in_chunk
            total_elements += elements_in_chunk
            elements_in_chunks.append(elements_in_chunk)

        final_loss = sum_loss / total_elements
        ctx.save_for_backward(X, labels)
        ctx.linear = linear
        ctx.forward_function = forward_function
        ctx.labels_dim = labels.dim()
        ctx.split_sizes = split_sizes  # valid only if labels.dim() == 2
        ctx.elements_in_chunks = elements_in_chunks
        ctx.total_elements = total_elements
        ctx.num_chunks = num_chunks
        return final_loss

    @staticmethod
    def backward(ctx, dY):
        X, labels = ctx.saved_tensors
        linear = ctx.linear
        forward_function = ctx.forward_function
        num_chunks = ctx.num_chunks
        elements_in_chunks = ctx.elements_in_chunks
        total_elements = ctx.total_elements
        X_chunks = X.chunk(num_chunks, dim=0)
        if ctx.labels_dim == 2:
            split_sizes = ctx.split_sizes
            labels_flat = labels.view(-1)
            labels_chunks = torch.split(labels_flat, split_sizes, dim=0)
        else:
            labels_chunks = labels.chunk(num_chunks, dim=0)

        dX = torch.zeros_like(X)
        dW = torch.zeros_like(linear.weight)
        dB = torch.zeros_like(linear.bias) if linear.bias is not None else None

        for i in range(num_chunks):
            x_chunk = X_chunks[i].clone().detach().requires_grad_(True)
            labels_chunk = labels_chunks[i]
            elements_in_chunk = elements_in_chunks[i]
            scale = dY * (elements_in_chunk / total_elements)
            with torch.enable_grad():
                chunk_loss = forward_function(x_chunk, linear, labels_chunk)
                grad_x, grad_w, grad_b = torch.autograd.grad(
                    chunk_loss,
                    (x_chunk, linear.weight, linear.bias),
                    grad_outputs=scale,
                    retain_graph=False,
                    allow_unused=False,
                )
            start_idx = i * x_chunk.size(0)
            end_idx = start_idx + x_chunk.size(0)
            dX[start_idx:end_idx] = grad_x
            dW += grad_w
            if grad_b is not None:
                dB += grad_b
        return dX, None, None, None

# ---------------- Setup for CrossEntropy Test ----------------
batch_size = 64
seq_len = 512
hidden_dim = 768
vocab_size = 16000
device = 'cuda'

# Create input and labels with requires_grad=True for input_data
input_data = torch.randn(batch_size, seq_len, hidden_dim,
                         dtype=torch.bfloat16, device=device, requires_grad=True)
labels = torch.randint(0, vocab_size, (batch_size, seq_len),
                       dtype=torch.long, device=device)

def create_linear():
    linear = torch.nn.Linear(hidden_dim, vocab_size, device=device)
    linear.weight = torch.nn.Parameter(linear.weight.to(torch.bfloat16))
    if linear.bias is not None:
        linear.bias.data = linear.bias.data.to(torch.bfloat16)
    return linear

In [2]:
targets_mse = torch.randn(batch_size, seq_len, vocab_size, dtype=torch.float32, device=device)

def transformation_function_mse(batch, linear, targets):
    x = linear(batch).float()
    # x = torch.relu(x)  # Activation
    mse_loss_fn = torch.nn.MSELoss(reduction="mean")
    loss = mse_loss_fn(x, targets)
    return loss

linear_standard_mse = create_linear()
linear_custom_mse = create_linear()
linear_custom_mse.load_state_dict(linear_standard_mse.state_dict())

def standard_forward_backward_mse():
    linear_standard_mse.zero_grad()
    if input_data.grad is not None:
        input_data.grad.zero_()
    logits = linear_standard_mse(input_data)
    activated = torch.relu(logits.float())
    loss = torch.nn.MSELoss(reduction='mean')(activated, targets_mse)
    loss.backward()
    return loss.detach()

def custom_forward_backward_mse():
    linear_custom_mse.zero_grad()
    if input_data.grad is not None:
        input_data.grad.zero_()
    loss = MemoryEfficientLinear.apply(input_data, linear_custom_mse, targets_mse, transformation_function_mse)
    loss.backward()
    return loss.detach()

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
loss_standard_mse = standard_forward_backward_mse()

torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
loss_custom_mse = custom_forward_backward_mse()
loss_diff_mse = torch.abs(loss_standard_mse - loss_custom_mse).item()

print("\nMSE Loss Test (with ReLU activation):")
print(f"Standard MSE Loss: {loss_standard_mse.item():.6f}")
print(f"Custom MSE Loss: {loss_custom_mse.item():.6f}")
print(f"MSE Loss Difference: {loss_diff_mse:.6f}")


MSE Loss Test (with ReLU activation):
Standard MSE Loss: 1.166796
Custom MSE Loss: 1.333750
MSE Loss Difference: 0.166953
