# Understanding ResNet Memory Footprint
In this notebook, we explore ResNet memory usage and the effect of gradient checkpointing. The following sections demonstrate:
- Installing required libraries
- Loading and preparing the Tiny ImageNet dataset
- Measuring GPU memory usage with varying batch sizes
- Comparing results with and without gradient checkpointing

## Install Dependencies

In [None]:
!pip install datasets
!pip install matplotlib


## Imports and Setup

In [None]:
import torch
import torch.nn as nn
import time
import numpy as np
import matplotlib.pyplot as plt

from torch.profiler import record_function
from datasets import load_dataset
from torchvision import transforms, models

# Seed the NumPy and PyTorch random number generators with the same value
# so repeated runs draw the same random numbers.
np.random.seed(710)
torch.manual_seed(710)

# Report how many CUDA devices are visible, then select the first GPU when
# CUDA is usable and fall back to the CPU otherwise.
print(f"Number of CUDA devices: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

## Load Dataset

In [None]:
# Load the training split of the "Maysee/tiny-imagenet" dataset via `load_dataset`.
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="train")
print(f"Sample record: {tiny_imagenet[0]}")

# The "label" feature exposes the class names; their count is the number of classes.
label_feature = tiny_imagenet.features["label"]
num_classes = len(label_feature.names)
print(f"Number of classes: {num_classes}")

## Prepare PyTorch Dataset

In [None]:
from PIL import Image

class TinyImageNet(torch.utils.data.Dataset):
    """Map-style PyTorch dataset over a collection of image/label records.

    Every record of the wrapped dataset is read as ``record["image"]`` (an
    object offering ``convert("RGB")``) and ``record["label"]`` (a value that
    ``torch.tensor`` can turn into an int64 scalar).

    Args:
        dataset: Indexable, sized collection of records as described above.
        transform: Optional callable applied to the RGB-converted image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for record ``idx``.

        The image is converted to RGB and passed through ``transform`` when
        one was given; the label is returned as an int64 tensor.
        """
        # Fetch the record once: indexing the wrapped dataset separately for
        # the image and the label would repeat the row lookup (and any image
        # decoding the dataset performs on access).
        record = self.dataset[idx]
        image = record["image"].convert("RGB")
        # Compare against None so a callable that happens to be falsy is still applied.
        if self.transform is not None:
            image = self.transform(image)
        label = torch.tensor(record["label"], dtype=torch.int64)
        return image, label

# Preprocessing pipeline: resize each image to 224x224, then convert it to a tensor.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

# Wrap the loaded dataset so it yields (image, label) pairs with the pipeline applied.
tiny_imagenet_torch = TinyImageNet(dataset=tiny_imagenet, transform=transform)
print(f"Sample torch dataset element shape: {tiny_imagenet_torch[0][0].shape}")

## Load Model to GPU

In [None]:
# Start recording CUDA memory history, capped via max_entries=10000.
# NOTE(review): `_record_memory_history` is a private torch API, and nothing in the
# visible cells dumps or stops the recording — confirm it is still needed.
torch.cuda.memory._record_memory_history(max_entries=10000)

# Estimate the model's GPU footprint as the change in allocated CUDA memory
# measured immediately before and after the model is moved to `device`.
model_gpu_usage_before = torch.cuda.memory_allocated(device)
# NOTE(review): newer torchvision releases favour the `weights=` argument over
# `pretrained=` — confirm against the installed version.
model = models.resnet18(pretrained=True).to(device)
model_gpu_usage_after = torch.cuda.memory_allocated(device)
model_gpu_usage = model_gpu_usage_after - model_gpu_usage_before
print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
# Bytes -> MB (divide by 1024**2).
print(f"Model GPU usage: {model_gpu_usage / 1024**2:.2f} MB")

# Drop this reference; the profiling runs below construct their own models.
del model


## Training and Memory Profiling
The following function trains the model for a few batches while profiling memory usage.

In [None]:
def fit(model, train_loader, val_loader, epochs=1, lr=0.001, break_after_num_batches=None, title=""):
    """Train ``model`` briefly under the PyTorch profiler and export a memory timeline.

    Uses Adam and cross-entropy loss. Each training phase (device transfer,
    forward, backward, optimizer step) is wrapped in a ``record_function``
    label so it can be identified in the profile. Batches are moved to the
    notebook-level ``device``.

    Args:
        model: Module to train; expected to already live on ``device``.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        val_loader: Accepted for interface compatibility; not used here.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate for the Adam optimizer.
        break_after_num_batches: When not None, stop each epoch once this many
            batches have been processed.
        title: Prefix of the exported file, ``f"{title}_memory.html"``.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        # No wait/warmup: record the first 6 profiler steps, once.
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        record_shapes=True,
        with_stack=True,
        profile_memory=True
    ) as prof:
        for epoch in range(epochs):
            for batch_idx, (inputs, labels) in enumerate(train_loader):
                with record_function("to_device"):
                    inputs, labels = inputs.to(device), labels.to(device)
                with record_function("forward"):
                    outputs = model(inputs)
                with record_function("backward"):
                    criterion(outputs, labels).backward()
                with record_function("optimizer_step"):
                    optimizer.step()
                    optimizer.zero_grad()

                # Mark the END of the iteration. Stepping at the start of the
                # loop body would close an empty first step and leave only
                # five of the six active steps holding real training work.
                prof.step()

                # batch_idx is zero-based, so batch_idx + 1 batches are done by now.
                if break_after_num_batches is not None and batch_idx + 1 >= break_after_num_batches:
                    break

    # Export for the device the batches were actually placed on rather than a
    # hard-coded "cuda:0".
    prof.export_memory_timeline(f"{title}_memory.html", device=str(device))

def clear_cuda_memory():
    """Release unused GPU memory and print the remaining allocated/cached amounts.

    Runs the Python garbage collector, and then, if CUDA is available, waits
    for pending GPU work, empties the CUDA caching allocator and resets the
    peak-memory statistics. Without CUDA it only runs the garbage collector
    and prints a notice.
    """
    import gc

    # Collect first: tensors kept alive only by reference cycles must be freed
    # before empty_cache() can hand their blocks back to the driver.
    gc.collect()

    if not torch.cuda.is_available():
        print("CUDA is not available; skipped clearing GPU memory")
        return

    # Let queued kernels finish so their memory can actually be released.
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

def fit_helper(model_type, dataset, epochs, break_after_num_batches, batch_sizes, num_workers, title):
    """Profile training for each batch size, stopping at the first failure.

    For every entry of ``batch_sizes`` a fresh model and fresh data loaders
    are created, ``fit`` is run, and GPU memory is cleaned up afterwards. The
    sweep stops after the first batch size that raises an out-of-memory or
    other runtime error.

    Args:
        model_type: ``"resnet18_without_checkpointing"`` or
            ``"resnet18_with_checkpointing"``.
        dataset: Dataset handed to the training and validation loaders.
        epochs: Number of epochs forwarded to ``fit``.
        break_after_num_batches: Per-epoch batch limit forwarded to ``fit``.
        batch_sizes: Iterable of batch sizes to try, in order.
        num_workers: Worker count for the data loaders.
        title: Prefix for the exported timelines; each run is written to
            ``f"{title}_bs{batch_size}_memory.html"`` so runs do not overwrite
            one another.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    supported_model_types = ("resnet18_without_checkpointing", "resnet18_with_checkpointing")
    if model_type not in supported_model_types:
        # Fail early with a clear message instead of an unbound `model` later on.
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {supported_model_types}"
        )

    for batch_size in batch_sizes:
        if model_type == "resnet18_without_checkpointing":
            model = models.resnet18(pretrained=True)
        else:
            model = ResnetCheckpointed()
        model.to(device)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

        oom_break = False
        try:
            fit(
                model,
                train_loader,
                val_loader,
                epochs=epochs,
                break_after_num_batches=break_after_num_batches,
                title=f"{title}_bs{batch_size}",
            )
            print(f"Processed for batch size {batch_size}")
        # OutOfMemoryError derives from RuntimeError, so it has to be handled
        # first; otherwise the generic handler below would always win.
        except torch.cuda.OutOfMemoryError:
            print(f"Out of memory for batch size {batch_size}")
            oom_break = True
        except RuntimeError as e:
            print(f"Runtime error for batch size {batch_size}: {e}")
            oom_break = True

        # Drop every reference before clearing so the memory can be reclaimed.
        del model
        del train_loader
        del val_loader
        clear_cuda_memory()
        time.sleep(5)
        if oom_break:
            break

## Run Profiling Without Gradient Checkpointing

In [None]:
# Settings shared by the profiling runs: loader workers, the per-run batch
# limit, and the batch sizes to sweep.
num_workers = 2
print(f"Number of workers: {num_workers}")
break_after_num_batches = 10
batch_sizes = [128 * 2**power for power in range(7)]  # 128, 256, ..., 8192

# Baseline sweep: plain ResNet-18 without gradient checkpointing.
fit_helper(
    model_type="resnet18_without_checkpointing",
    dataset=tiny_imagenet_torch,
    epochs=1,
    break_after_num_batches=break_after_num_batches,
    batch_sizes=batch_sizes,
    num_workers=num_workers,
    title="without_checkpointing",
)

## Define Model with Gradient Checkpointing

In [None]:
from torch.utils.checkpoint import checkpoint
import torchvision.models as models

class ResnetCheckpointed(nn.Module):
    """Pretrained ResNet-18 whose stem and residual stages use gradient checkpointing.

    The input stem (conv1 -> bn1 -> relu -> maxpool) and each of ``layer1`` to
    ``layer4`` run inside ``torch.utils.checkpoint.checkpoint``, so their
    intermediate activations are recomputed during the backward pass instead
    of being stored. The pooling/classifier head runs normally.

    NOTE(review): recomputation re-runs the BatchNorm layers in training mode,
    which presumably updates their running statistics a second time per step —
    confirm whether that matters for your use case.
    """

    def __init__(self):
        super().__init__()
        self.model = models.resnet18(pretrained=True)

        # Aliases to the wrapped network's sub-modules (same objects, not copies).
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.relu = self.model.relu
        self.maxpool = self.model.maxpool
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        self.avgpool = self.model.avgpool
        self.fc = self.model.fc

    def _stem(self, x):
        """Run the input stem: conv1 -> bn1 -> relu -> maxpool."""
        return self.maxpool(self.relu(self.bn1(self.conv1(x))))

    def forward(self, x):
        """Return the classifier output for input batch ``x``."""
        # use_reentrant=False is required for correctness here: with the
        # reentrant implementation, a checkpoint whose inputs do not require
        # grad (the raw image batch) produces an output detached from the
        # graph, so the checkpointed backbone would receive no gradients.
        #
        # The stem is checkpointed as ONE segment; wrapping single ops such as
        # a lone conv or pooling layer adds recomputation without saving memory.
        x = checkpoint(self._stem, x, use_reentrant=False)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = checkpoint(stage, x, use_reentrant=False)

        # The head works on small tensors, so it is not checkpointed.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

## Run Profiling With Gradient Checkpointing

In [None]:
# Same sweep as the baseline, this time with the gradient-checkpointed model.
fit_helper(
    model_type="resnet18_with_checkpointing",
    dataset=tiny_imagenet_torch,
    epochs=1,
    break_after_num_batches=break_after_num_batches,
    batch_sizes=batch_sizes,
    num_workers=num_workers,
    title="with_checkpointing",
)

## Conclusions
- Without gradient checkpointing, we were able to fit up to a batch size of 1024 (before running out of memory).
- With gradient checkpointing, we could fit larger batch sizes (up to 4096) on the same GPU.

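For memory-intensive tasks, gradient checkpointing can be a useful technique: it discards intermediate activations during the forward pass and recomputes them during the backward pass, trading extra compute for lower peak memory.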