<a href="https://colab.research.google.com/github/Datbwoyyy/SlothAi/blob/main/Untitled16.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [7]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import gc
from contextlib import contextmanager
import torch.utils.checkpoint as checkpoint  # Correct import

# Use torch.cuda.amp.autocast if a GPU is available; otherwise, create a dummy context.
if torch.cuda.is_available():
    # NOTE(review): recent PyTorch releases deprecate torch.cuda.amp.autocast in
    # favour of torch.amp.autocast("cuda", ...); consider migrating.
    from torch.cuda.amp import autocast
else:
    @contextmanager
    def autocast(*args, **kwargs):
        """No-op stand-in for ``torch.cuda.amp.autocast`` on CPU-only machines.

        Accepts and ignores any positional/keyword arguments (e.g. ``enabled``,
        ``dtype``) so that call sites written for the CUDA version keep working
        unchanged when no GPU is present.
        """
        yield

class MemoryEfficientFunction(torch.autograd.Function):
    """Compute ``transform_fn(X @ W)`` while saving only ``X`` and ``W`` for backward.

    ``forward`` runs with autograd recording disabled (as every
    ``torch.autograd.Function.forward`` does), so autograd keeps none of the
    intermediates of the matmul or of ``transform_fn`` alive. ``backward``
    recomputes them from the saved inputs (the same trade-off as activation
    checkpointing) and differentiates through the recomputation.

    The returned tensor itself is still materialised in full; the saving is only
    in what autograd stores between forward and backward.
    """

    @staticmethod
    def forward(ctx, X, W, transform_fn):
        """Return ``transform_fn(X.matmul(W))``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            X: left matmul operand.
            W: right matmul operand.
            transform_fn: callable applied to the product (e.g. log_softmax,
                sigmoid). It must be built from differentiable torch ops because
                ``backward`` re-runs it under autograd. It is not a tensor, so
                it never receives a gradient.
        """
        ctx.save_for_backward(X, W)
        ctx.transform_fn = transform_fn
        # Grad mode is already off inside Function.forward, so wrapping this in
        # torch.utils.checkpoint would add overhead without saving any memory.
        return transform_fn(X.matmul(W))

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Recompute the forward pass and return gradients for ``X`` and ``W``.

        Only the inputs flagged in ``ctx.needs_input_grad`` are differentiated,
        so a frozen ``W`` or a plain-data ``X`` (``requires_grad=False``) works
        instead of making ``torch.autograd.grad`` raise. The function is marked
        ``once_differentiable`` because the gradients are computed without
        ``create_graph``, so a double backward must error rather than silently
        return wrong higher-order gradients.

        Returns:
            ``(grad_X, grad_W, None)``. An entry is ``None`` for an input that
            does not need a gradient, and always ``None`` for ``transform_fn``.
        """
        X, W = ctx.saved_tensors
        need_X, need_W = ctx.needs_input_grad[:2]
        if not (need_X or need_W):
            return None, None, None

        # Recompute on detached leaves so the recomputed graph starts exactly at
        # X_/W_ and never reaches back into whatever produced X or W.
        with torch.enable_grad():
            X_ = X.detach().requires_grad_(need_X)
            W_ = W.detach().requires_grad_(need_W)
            recomputed = ctx.transform_fn(X_.matmul(W_))

        wrt = [t for t, need in ((X_, need_X), (W_, need_W)) if need]
        grads = list(torch.autograd.grad(recomputed, wrt, grad_outputs=grad_output))
        grad_X = grads.pop(0) if need_X else None
        grad_W = grads.pop(0) if need_W else None
        return grad_X, grad_W, None

def memory_efficient_forward(X, W, transform_fn):
    """Evaluate ``transform_fn(X @ W)`` through ``MemoryEfficientFunction``.

    Args:
        X: left matmul operand.
        W: right matmul operand.
        transform_fn: callable applied to the matmul result.

    Returns:
        The result of ``MemoryEfficientFunction.apply(X, W, transform_fn)``.
    """
    apply_fn = MemoryEfficientFunction.apply
    return apply_fn(X, W, transform_fn)

# --- Example Usages ---

# Example 1: Cross Entropy Loss (using log softmax)
def example_cross_entropy(bsz=4, qlen=1024, hd=1024, vocab=32000):
    """Demo: log-softmax via memory_efficient_forward followed by NLL loss.

    Args:
        bsz: batch size; together with ``qlen`` gives ``bsz * qlen`` rows of X.
        qlen: sequence length.
        hd: inner matmul dimension (columns of X, rows of W).
        vocab: number of classes (columns of W; targets are drawn from [0, vocab)).

    Returns:
        The scalar loss as a Python float (also printed).
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Half precision is a GPU optimisation; on CPU fall back to float32, where
    # half kernels are slow (and missing for some ops in older PyTorch builds).
    dtype = torch.float16 if use_cuda else torch.float32
    n_rows = bsz * qlen

    # Allocate directly on the target device so X and W are leaves there and
    # receive .grad there. A later `.cuda()` would instead make non-leaf copies
    # whose gradients flow back into CPU-side leaf tensors.
    X = torch.randn(n_rows, hd, dtype=dtype, device=device, requires_grad=True)
    W = torch.randn(hd, vocab, dtype=dtype, device=device, requires_grad=True)
    targets = torch.randint(0, vocab, (n_rows,), device=device)

    def transform_fn(logits):
        # Log softmax over classes, so nll_loss below yields cross entropy.
        return F.log_softmax(logits, dim=-1)

    output = memory_efficient_forward(X, W, transform_fn)
    loss = F.nll_loss(output, targets)
    loss.backward()
    loss_value = loss.item()
    print("Example 1 (Cross Entropy): Loss =", loss_value)
    return loss_value

# Example 2: Sigmoid Activation with MSE Loss
def example_sigmoid_mse(bsz=4, qlen=1024, hd=1024, vocab=32000):
    """Demo: sigmoid via memory_efficient_forward followed by MSE loss.

    Args:
        bsz: batch size; together with ``qlen`` gives ``bsz * qlen`` rows of X.
        qlen: sequence length.
        hd: inner matmul dimension (columns of X, rows of W).
        vocab: number of output columns (columns of W and of the targets).

    Returns:
        The scalar loss as a Python float (also printed).
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    # Half precision is a GPU optimisation; on CPU fall back to float32.
    dtype = torch.float16 if use_cuda else torch.float32
    n_rows = bsz * qlen

    # Allocate directly on the target device so X and W stay leaf tensors there
    # (no CPU leaves plus non-leaf `.cuda()` copies).
    X = torch.randn(n_rows, hd, dtype=dtype, device=device, requires_grad=True)
    W = torch.randn(hd, vocab, dtype=dtype, device=device, requires_grad=True)
    targets = torch.randn(n_rows, vocab, dtype=dtype, device=device)

    output = memory_efficient_forward(X, W, torch.sigmoid)
    loss = F.mse_loss(output, targets)
    loss.backward()
    loss_value = loss.item()
    print("Example 2 (Sigmoid + MSE): Loss =", loss_value)
    return loss_value

if __name__ == '__main__':
    # Run the two demos back to back; each prints its own loss line.
    example_cross_entropy()  # log-softmax + NLL
    example_sigmoid_mse()  # sigmoid + MSE


Example 1 (Cross Entropy): Loss = 131.125
Example 2 (Sigmoid + MSE): Loss = 1.4853515625


In [11]:
import torch
import torch.nn.functional as F
from tqdm import tqdm
import gc
from contextlib import contextmanager
import torch.utils.checkpoint as checkpoint  # Correct import

# Use torch.cuda.amp.autocast if a GPU is available; otherwise, create a dummy context.
if torch.cuda.is_available():
    # NOTE(review): recent PyTorch releases deprecate torch.cuda.amp.autocast in
    # favour of torch.amp.autocast("cuda", ...); consider migrating.
    from torch.cuda.amp import autocast
else:
    @contextmanager
    def autocast(*args, **kwargs):
        """No-op stand-in for ``torch.cuda.amp.autocast`` on CPU-only machines.

        Accepts and ignores any positional/keyword arguments (e.g. ``enabled``,
        ``dtype``) so that call sites written for the CUDA version keep working
        unchanged when no GPU is present.
        """
        yield

# Enable TF32 for faster matmul on NVIDIA GPUs.
# This flag only affects float32 CUDA matmuls; half/bfloat16 matmuls are unaffected.
torch.backends.cuda.matmul.allow_tf32 = True

class MemoryEfficientFunction(torch.autograd.Function):
    """Compute ``transform_fn(X @ W)`` while saving only ``X`` and ``W`` for backward.

    ``forward`` runs with autograd recording disabled (as every
    ``torch.autograd.Function.forward`` does), so autograd keeps none of the
    intermediates of the matmul or of ``transform_fn`` alive. ``backward``
    recomputes them from the saved inputs (the same trade-off as activation
    checkpointing) and differentiates through the recomputation.

    The returned tensor itself is still materialised in full; the saving is only
    in what autograd stores between forward and backward.
    """

    @staticmethod
    def forward(ctx, X, W, transform_fn):
        """Return ``transform_fn(X.matmul(W))``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            X: left matmul operand.
            W: right matmul operand.
            transform_fn: callable applied to the product (e.g. log_softmax,
                sigmoid). It must be built from differentiable torch ops because
                ``backward`` re-runs it under autograd. It is not a tensor, so
                it never receives a gradient.
        """
        ctx.save_for_backward(X, W)
        ctx.transform_fn = transform_fn
        # Grad mode is already off inside Function.forward, so wrapping this in
        # torch.utils.checkpoint would add overhead without saving any memory.
        return transform_fn(X.matmul(W))

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Recompute the forward pass and return gradients for ``X`` and ``W``.

        Only the inputs flagged in ``ctx.needs_input_grad`` are differentiated,
        so a frozen ``W`` or a plain-data ``X`` (``requires_grad=False``) works
        instead of making ``torch.autograd.grad`` raise. The function is marked
        ``once_differentiable`` because the gradients are computed without
        ``create_graph``, so a double backward must error rather than silently
        return wrong higher-order gradients.

        Returns:
            ``(grad_X, grad_W, None)``. An entry is ``None`` for an input that
            does not need a gradient, and always ``None`` for ``transform_fn``.
        """
        X, W = ctx.saved_tensors
        need_X, need_W = ctx.needs_input_grad[:2]
        if not (need_X or need_W):
            return None, None, None

        # Recompute on detached leaves so the recomputed graph starts exactly at
        # X_/W_ and never reaches back into whatever produced X or W.
        with torch.enable_grad():
            X_ = X.detach().requires_grad_(need_X)
            W_ = W.detach().requires_grad_(need_W)
            recomputed = ctx.transform_fn(X_.matmul(W_))

        wrt = [t for t, need in ((X_, need_X), (W_, need_W)) if need]
        grads = list(torch.autograd.grad(recomputed, wrt, grad_outputs=grad_output))
        grad_X = grads.pop(0) if need_X else None
        grad_W = grads.pop(0) if need_W else None
        return grad_X, grad_W, None

def memory_efficient_forward(X, W, transform_fn):
    """Evaluate ``transform_fn(X @ W)`` through ``MemoryEfficientFunction``.

    Args:
        X: left matmul operand.
        W: right matmul operand.
        transform_fn: callable applied to the matmul result.

    Returns:
        The result of ``MemoryEfficientFunction.apply(X, W, transform_fn)``.
    """
    apply_fn = MemoryEfficientFunction.apply
    return apply_fn(X, W, transform_fn)

# --- Example Usages ---
def get_dtype():
    """Pick the floating-point dtype used by the examples.

    Returns:
        ``torch.bfloat16`` when CUDA device 0 has compute capability major >= 8,
        ``torch.float16`` on older CUDA devices, and ``torch.float32`` when CUDA
        is not available. Half-precision kernels on CPU are slow, and some ops
        lack Half CPU kernels in older PyTorch releases.
    """
    if not torch.cuda.is_available():
        return torch.float32
    major, _minor = torch.cuda.get_device_capability(0)
    return torch.bfloat16 if major >= 8 else torch.float16

# Example 1: Cross Entropy Loss (using log softmax)
def example_cross_entropy(bsz=4, qlen=1024, hd=1024, vocab=32000):
    """Demo: log-softmax via memory_efficient_forward followed by NLL loss.

    Args:
        bsz: batch size; together with ``qlen`` gives ``bsz * qlen`` rows of X.
        qlen: sequence length.
        hd: inner matmul dimension (columns of X, rows of W).
        vocab: number of classes (columns of W; targets are drawn from [0, vocab)).

    Returns:
        The scalar loss as a Python float (also printed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = get_dtype()
    n_rows = bsz * qlen

    # Allocate directly on the target device so X and W are leaves there and
    # receive .grad there. A later `.cuda()` would instead make non-leaf copies
    # whose gradients flow back into CPU-side leaf tensors.
    X = torch.randn(n_rows, hd, dtype=dtype, device=device, requires_grad=True)
    W = torch.randn(hd, vocab, dtype=dtype, device=device, requires_grad=True)
    targets = torch.randint(0, vocab, (n_rows,), device=device)

    def transform_fn(logits):
        # Log softmax over classes, so nll_loss below yields cross entropy.
        return F.log_softmax(logits, dim=-1)

    output = memory_efficient_forward(X, W, transform_fn)
    loss = F.nll_loss(output, targets)
    loss.backward()
    loss_value = loss.item()
    print("Example 1 (Cross Entropy): Loss =", loss_value)
    return loss_value

# Example 2: Sigmoid Activation with MSE Loss
def example_sigmoid_mse(bsz=4, qlen=1024, hd=1024, vocab=32000):
    """Demo: sigmoid via memory_efficient_forward followed by MSE loss.

    Args:
        bsz: batch size; together with ``qlen`` gives ``bsz * qlen`` rows of X.
        qlen: sequence length.
        hd: inner matmul dimension (columns of X, rows of W).
        vocab: number of output columns (columns of W and of the targets).

    Returns:
        The scalar loss as a Python float (also printed).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = get_dtype()
    n_rows = bsz * qlen

    # Allocate directly on the target device so X and W stay leaf tensors there
    # (no CPU leaves plus non-leaf `.cuda()` copies).
    X = torch.randn(n_rows, hd, dtype=dtype, device=device, requires_grad=True)
    W = torch.randn(hd, vocab, dtype=dtype, device=device, requires_grad=True)
    targets = torch.randn(n_rows, vocab, dtype=dtype, device=device)

    output = memory_efficient_forward(X, W, torch.sigmoid)
    loss = F.mse_loss(output, targets)
    loss.backward()
    loss_value = loss.item()
    print("Example 2 (Sigmoid + MSE): Loss =", loss_value)
    return loss_value

if __name__ == '__main__':
    # Run the two demos back to back; each prints its own loss line.
    example_cross_entropy()  # log-softmax + NLL
    example_sigmoid_mse()  # sigmoid + MSE


Example 1 (Cross Entropy): Loss = 131.875
Example 2 (Sigmoid + MSE): Loss = 1.4853515625
