[Reference: How to Optimize Memory Usage for Training LLMs in PyTorch](https://levelup.gitconnected.com/how-to-optimize-memory-usage-for-training-llms-in-pytorch-b012f3008798)

# 1. Automatic Mixed-Precision Training

In [2]:
import torch
from torch.cuda.amp import autocast, GradScaler

# `MyModel`, `data_loader` and `loss_fn` are assumed to be defined elsewhere (TODO confirm).
# NOTE: newer PyTorch releases deprecate the torch.cuda.amp entry points in favor of
# torch.amp.autocast("cuda") and torch.amp.GradScaler("cuda"); behavior is the same.
model = MyModel().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# GradScaler multiplies the loss before backward so small float16 gradients do not
# underflow to zero, and unscales the gradients again before the optimizer step.
scaler = GradScaler()

model.train()
for data, target in data_loader:
    # The model lives on the GPU, so each batch must be moved there as well (a no-op if the
    # loader already yields CUDA tensors). non_blocking=True lets the copy overlap with compute
    # when the loader uses pinned memory. Assumes data/target are tensors — TODO confirm.
    data = data.cuda(non_blocking=True)
    target = target.cuda(non_blocking=True)

    optimizer.zero_grad(set_to_none=True)
    # Run the forward pass and the loss under mixed precision.
    with autocast():
        output = model(data)
        loss = loss_fn(output, target)

    # Backpropagate the scaled loss. scaler.step() unscales the gradients and skips the
    # optimizer step if they contain inf/NaN; scaler.update() then adjusts the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

# 2. Lower-Precision Training

In [3]:
import torch

# bfloat16 has float32's exponent range, so it usually trains without a GradScaler, but it
# needs hardware support. Check for a CUDA device first: on CPU-only machines some PyTorch
# versions raise from is_bf16_supported() instead of returning False.
bf16_supported = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
print(bf16_supported)  # True only if a CUDA device with bfloat16 support is present

# Reduced-precision dtype to pass to autocast, e.g. torch.autocast("cuda", dtype=amp_dtype).
amp_dtype = torch.bfloat16 if bf16_supported else torch.float16

True


# 3. Gradient Checkpointing

In [5]:
import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_segment(input_tensor):
    """Run the portion of the model whose activations are recomputed rather than stored.

    During the forward pass under `checkpoint`, intermediate activations inside this
    function are not kept; they are recomputed during the backward pass.

    `model_segment` is assumed to be a module/callable defined in an earlier cell
    (TODO confirm).
    """
    return model_segment(input_tensor)

# Wrap the segment with checkpoint instead of calling it directly.
# use_reentrant=False is the implementation PyTorch recommends: recent versions warn when the
# argument is omitted, and the reentrant variant produces no gradients for the segment's
# parameters unless at least one input tensor requires grad.
output = checkpoint(checkpointed_segment, input_tensor, use_reentrant=False)

# 4. Tensor Sharding and Distributed Training

In [6]:
import functools
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Submodules holding at least this many not-yet-wrapped parameters become their own FSDP
# unit, so only one unit's full parameters are gathered on a GPU at a time.
# Without an auto_wrap_policy the whole model is a single unit. Tune per model.
FSDP_MIN_PARAMS_PER_UNIT = 1_000_000

# FSDP requires an initialized process group. Under `torchrun` the rendezvous variables are
# already set; the setdefault() fallbacks only make a single-process notebook run work.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    dist.init_process_group(backend="nccl")

# Each process drives its own GPU; torchrun sets LOCAL_RANK per process.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

# Build the model on the CPU and let FSDP move it to this rank's GPU while sharding,
# instead of first materializing the full, unsharded model on the GPU with .cuda().
model = MyLargeModel()

# Wrap the model in FSDP for sharded training across GPUs.
fsdp_model = FSDP(
    model,
    auto_wrap_policy=functools.partial(
        size_based_auto_wrap_policy, min_num_params=FSDP_MIN_PARAMS_PER_UNIT
    ),
    device_id=local_rank,
)

# 5. Efficient Data Loading

In [7]:
import torch
from torch.utils.data import DataLoader

# Create the DataLoader for `dataset` (assumed to be defined in an earlier cell — TODO confirm).
# Pinned (page-locked) host memory only speeds up host-to-GPU copies, and it enables the
# `.cuda(non_blocking=True)` transfers used in the training loop. Without a GPU it is
# ignored with a warning, so enable it only when CUDA is available.
train_loader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,                         # Adjust to the number of available CPU cores
    pin_memory=torch.cuda.is_available(),  # Faster, asynchronous host-to-device transfers
)

# 6. Use In-Place Operations

In [8]:
import torch

# Fall back to the CPU so the demonstration also runs on machines without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

x = torch.randn(100, 100, device=device)
y = torch.randn(100, 100, device=device)

# Record where x's data lives so we can confirm that no new tensor is allocated.
x_ptr_before = x.data_ptr()

# In-place addition: the sum is written into x's existing storage (x.data_ptr() is unchanged).
# Caution: autograd raises an error if an in-place op overwrites a value saved for backward.
x.add_(y)

tensor([[-3.3012,  0.4687,  1.3011,  ...,  0.4248, -0.6974,  0.3895],
        [-0.1602, -1.4334, -0.1336,  ..., -1.0948, -1.0826,  1.3182],
        [ 1.1904,  2.6363,  0.8543,  ...,  1.7735,  0.4995,  1.3008],
        ...,
        [-2.7956, -2.2139, -1.1450,  ...,  0.8583, -0.9636,  0.0385],
        [ 1.0686,  1.5596, -0.9039,  ..., -0.6360, -1.5732,  2.1149],
        [ 0.3814,  1.7975,  2.0598,  ...,  0.2565, -0.1484, -2.2323]],
       device='cuda:0')

# 7. Activation and Parameter Offloading

In [9]:
import torch


def offload_activation(tensor):
    """Return a CPU copy of `tensor`.

    Kept for code that offloads a single tensor by hand. An explicit ``.cpu()`` followed
    by ``.cuda()`` on a forward-pass value generally does not reduce what autograd keeps
    on the GPU for the backward pass; ``process_batch`` uses ``save_on_cpu`` for that.
    """
    return tensor.cpu()


def process_batch(data):
    """Run ``model.layer1`` then ``model.layer2`` on ``data``, keeping saved activations on CPU.

    ``model`` is a global assumed to be defined in an earlier cell (TODO confirm).

    Inside ``torch.autograd.graph.save_on_cpu``, every tensor that autograd saves for the
    backward pass is copied to host memory. It is copied back only when backward needs it,
    so those activations do not occupy GPU memory between the forward and backward passes.

    Returns:
        The output of ``model.layer2``.
    """
    # Pinned host buffers allow asynchronous copies, but pinning requires CUDA.
    with torch.autograd.graph.save_on_cpu(pin_memory=torch.cuda.is_available()):
        intermediate = model.layer1(data)
        output = model.layer2(intermediate)
    return output

# 8. Using a Leaner Optimizer

In [11]:
import torch

# Instead of Adam, which keeps two extra state tensors per parameter (running averages of
# the gradient and of its square, each the size of the parameter):
#     optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
# use plain SGD, which keeps no per-parameter state when momentum is 0 (the default).
# `model`, `NUM_EPOCHS` and `train_loader` are assumed to come from earlier cells (TODO confirm).
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Cosine-anneal the learning rate from 0.01 toward 0 over the whole run. T_max counts
# batches, so call scheduler.step() once per optimizer step (i.e. after every batch).
num_steps = NUM_EPOCHS * len(train_loader)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_steps)

# 9. Beyond the Basics

## 9.1) Memory Profiling and Cache Management

In [12]:
import torch

# torch.cuda.memory_summary() raises on machines without CUDA, so guard the whole cell.
if torch.cuda.is_available():
    # Print a detailed report of the caching allocator's current and peak usage
    # on the current device.
    print(torch.cuda.memory_summary(device=None, abbreviated=False))

    # Release cached blocks that no tensor currently occupies back to the driver (they then
    # show up as free in nvidia-smi). This does not free memory held by live tensors, and
    # calling it frequently can slow training because released blocks must be re-allocated.
    torch.cuda.empty_cache()
else:
    print("CUDA is not available; there is no GPU memory to report.")

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |  80896 B   |  84480 B   |  96256 B   |  15360 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |  80896 B   |  84480 B   |  96256 B   |  15360 B   |
|---------------------------------------------------------------------------|
| Active memory         |  80896 B   |  84480 B   |  96256 B   |  15360 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |  80896 B   |  84480 B   |  96256 B   |  15360 B   |
|---------------------------------------------------------------

## 9.2) JIT Compilation with TorchScript

In [14]:
import torch


def _compile_to_torchscript(module):
    """Compile an eager ``nn.Module`` with ``torch.jit.script`` and return the result."""
    return torch.jit.script(module)


# `model` (the eager network) and `input_tensor` are assumed to come from earlier cells.
scripted_model = _compile_to_torchscript(model)

# The scripted module is called exactly like the original one.
output = scripted_model(input_tensor)