# Understanding Checkpointed ResNet Time Footprint
## 1. Install and Imports

In [5]:
!pip install datasets
!pip install matplotlib
import itertools
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from datasets import load_dataset
from torch.nn.parallel import DistributedDataParallel
from torch.profiler import record_function
from torch.utils.data.distributed import DistributedSampler
from torchvision import transforms, models

# Seed the global torch and NumPy RNGs of this process for reproducibility.
# NOTE: processes started by mp.spawn are fresh interpreters and do not run this cell.
SEED = 710
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(SEED)


[0m

## 2. Prepare Dataset

In [6]:
class TinyImageNet(torch.utils.data.Dataset):
    """Wrap an image-classification record collection as a PyTorch ``Dataset``.

    Each record of ``dataset`` must be a mapping with an ``"image"`` entry that
    supports ``.convert(mode)`` (e.g. a PIL image) and an integer ``"label"``.

    Args:
        dataset: Sized, indexable collection of ``{"image": ..., "label": ...}``
            records (here a Hugging Face ``datasets`` split).
        transform: Optional callable applied to the RGB image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``; ``label`` is an int64 tensor."""
        # Index once: a Hugging Face dataset decodes the whole record (including
        # the image) on every __getitem__, so two lookups decode it twice.
        sample = self.dataset[idx]
        # Force 3 channels; presumably some source images are not RGB (e.g. grayscale).
        x = sample["image"].convert("RGB")
        if self.transform:
            x = self.transform(x)
        y = torch.tensor(sample["label"], dtype=torch.int64)
        return x, y

## 3. Enable CUDA Memory-History Recording

In [7]:
# Start recording CUDA caching-allocator alloc/free events, keeping at most
# MAX_MEMORY_HISTORY_ENTRIES of them.
# NOTE(review): no snapshot dump is visible in this notebook; nothing is measured here yet.
MAX_MEMORY_HISTORY_ENTRIES = 10_000
torch.cuda.memory._record_memory_history(max_entries=MAX_MEMORY_HISTORY_ENTRIES)

## 4. Define Checkpointed Model

In [8]:
from torch.utils.checkpoint import checkpoint

class ResnetCheckpointed(nn.Module):
    """ResNet-18 with activation checkpointing applied stage by stage.

    Checkpointed segments drop their intermediate activations after the forward
    pass and recompute them during backward, trading extra compute for lower
    peak activation memory.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet18(pretrained=...)``;
            ``False`` skips downloading weights.
    """

    def __init__(self, pretrained=True):
        super(ResnetCheckpointed, self).__init__()
        self.model = models.resnet18(pretrained=pretrained)

        # Alias the individual stages. They share parameters with ``self.model``;
        # ``parameters()`` de-duplicates, so each weight is optimised once.
        self.conv1 = self.model.conv1
        self.bn1 = self.model.bn1
        self.relu = self.model.relu
        self.maxpool = self.model.maxpool
        self.layer1 = self.model.layer1
        self.layer2 = self.model.layer2
        self.layer3 = self.model.layer3
        self.layer4 = self.model.layer4
        self.avgpool = self.model.avgpool
        self.fc = self.model.fc

    @staticmethod
    def _checkpoint(module, x):
        """Run ``module(x)`` as a non-reentrant checkpointed segment."""
        # use_reentrant=False is required for correctness. With the reentrant
        # variant, a segment whose input does not require grad (the raw images
        # fed to conv1) returns an output that does not require grad either.
        # Every later segment would inherit that, and only ``fc`` would be trained.
        return checkpoint(module, x, use_reentrant=False)

    def forward(self, x):
        """Return class logits for an image batch ``x`` (assumed NCHW — TODO confirm)."""
        # NOTE: BatchNorm layers inside checkpointed segments are re-run in
        # train mode during recomputation. Their running statistics are
        # therefore updated twice per step.
        x = self._checkpoint(self.conv1, x)
        x = self._checkpoint(self.bn1, x)
        x = self.relu(x)  # ReLU is memory-efficient, no need to checkpoint
        x = self._checkpoint(self.maxpool, x)
        x = self._checkpoint(self.layer1, x)
        x = self._checkpoint(self.layer2, x)
        x = self._checkpoint(self.layer3, x)
        x = self._checkpoint(self.layer4, x)
        x = self._checkpoint(self.avgpool, x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x


## 5. Training and Profiling

In [11]:
def _synchronize(device):
    """Wait for queued CUDA work on ``device`` so host-side timers measure finished work."""
    if device.type == "cuda":
        torch.cuda.synchronize(device)


def _plot_batch_times(batch_times, mean_time, title):
    """Line-plot per-batch wall-clock times with the mean shown in the legend."""
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(np.arange(len(batch_times)), batch_times, label=f"Mean time: {mean_time} s")
    ax.set_xlabel("Batch ID")
    ax.set_ylabel("Time (s)")
    ax.set_title(f"Training performance: {title}" if title else "Training performance")
    ax.legend()
    plt.show()


def fit(model, train_loader, val_loader, epochs=1, lr=0.001, break_after_num_batches=None, title=""):
    """Train ``model`` under the PyTorch profiler and plot per-batch wall-clock times.

    Each batch time spans from the end of the previous batch (or epoch start)
    to the end of this batch's optimizer step, so it includes data loading.

    Args:
        model: Module already moved to its target device; the device is taken
            from its first parameter.
        train_loader: Iterable of ``(inputs, labels)`` batches.
        val_loader: Currently unused; kept for interface compatibility.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.
        break_after_num_batches: If not None, process at most this many batches per epoch.
        title: Prefix of the exported ``{title}_memory.html`` file and plot title.

    Returns:
        dict with ``"batch_times"`` (np.ndarray of seconds per batch) and
        ``"mean_time"`` (their mean rounded to 2 decimals; nan if no batch ran).
    """
    # Read the device from the model instead of relying on a global ``device``,
    # which does not exist inside spawned worker processes.
    device = next(model.parameters()).device
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    batch_times = []

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        record_shapes=True,
        with_stack=True,
        profile_memory=True,
    ) as prof:
        for epoch in range(epochs):
            # DistributedSampler only reshuffles between epochs when told the epoch.
            sampler = getattr(train_loader, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(epoch)
            batches = train_loader
            if break_after_num_batches is not None:
                # islice stops after exactly N batches without fetching an extra one.
                batches = itertools.islice(train_loader, break_after_num_batches)

            _synchronize(device)
            start_time = time.perf_counter()
            for inputs, labels in batches:
                with record_function("to_device"):
                    inputs, labels = inputs.to(device), labels.to(device)
                with record_function("forward"):
                    outputs = model(inputs)
                with record_function("backward"):
                    criterion(outputs, labels).backward()
                with record_function("optimizer_step"):
                    optimizer.step()
                    optimizer.zero_grad()
                # CUDA kernels run asynchronously; sync so the timer covers the actual work.
                _synchronize(device)
                end_time = time.perf_counter()
                batch_times.append(end_time - start_time)
                prof.step()  # the profiler schedule advances at the end of each iteration
                start_time = end_time
    prof.export_memory_timeline(f"{title}_memory.html", device=str(device))

    batch_times = np.array(batch_times)
    mean_time = float(np.round(batch_times.mean(), 2)) if batch_times.size else float("nan")
    _plot_batch_times(batch_times, mean_time, title)
    return {"batch_times": batch_times, "mean_time": mean_time}


def clear_cuda_memory(device=None):
    """Free cached CUDA memory, reset peak statistics, and print current usage.

    Args:
        device: CUDA device whose statistics are reset and reported; ``None``
            means the current device.

    On hosts without CUDA only garbage collection runs and nothing is printed.
    """
    import gc

    # Collect first so tensors kept alive only by reference cycles are released
    # before the caching allocator is asked to return unused blocks.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize(device)
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    print(f"Allocated memory: {torch.cuda.memory_allocated(device) / 1024**2:.2f} MB")
    print(f"Cached memory: {torch.cuda.memory_reserved(device) / 1024**2:.2f} MB")

def _load_tiny_imagenet_train():
    """Load the Tiny-ImageNet train split as 224x224 image tensors."""
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])
    return TinyImageNet(load_dataset("Maysee/tiny-imagenet", split="train"), transform=transform)


def _build_model(model_type):
    """Plain ResNet-18 for "resnet18_without_checkpointing", otherwise the checkpointed variant."""
    if model_type == "resnet18_without_checkpointing":
        return models.resnet18(pretrained=True)
    return ResnetCheckpointed()


def _make_loaders(dataset, batch_size, num_workers, distributed):
    """Return ``(train_loader, val_loader)`` over ``dataset``, sharded per rank when ``distributed``."""
    if distributed:
        # Rank and world size are taken from the initialised process group.
        train_sampler = DistributedSampler(dataset, shuffle=True)
        val_sampler = DistributedSampler(dataset, shuffle=False)
    else:
        train_sampler = val_sampler = None
    # DataLoader raises ValueError when shuffle=True is combined with an explicit
    # sampler, so shuffle here only when no sampler does it.
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=train_sampler is None,
        num_workers=num_workers, sampler=train_sampler)
    val_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, sampler=val_sampler)
    return train_loader, val_loader


def fit_helper(model_type, dataset, epochs, break_after_num_batches, num_workers, batch_sizes, title, device_id):
    """Profile training of a fresh model for each batch size on ``cuda:{device_id}``.

    When a ``torch.distributed`` process group is initialised, the model is wrapped
    in ``DistributedDataParallel`` and data is sharded with ``DistributedSampler``.
    Otherwise this process trains on the whole dataset. The sweep stops at the
    first batch size that raises ``torch.cuda.OutOfMemoryError``.

    Args:
        model_type: ``"resnet18_without_checkpointing"`` selects torchvision's
            ResNet-18; any other value selects ``ResnetCheckpointed``.
        dataset: Map-style dataset of ``(image, label)``; if ``None`` the
            Tiny-ImageNet train split is loaded here.
        epochs: Passed to ``fit``.
        break_after_num_batches: Passed to ``fit``.
        num_workers: DataLoader worker processes.
        batch_sizes: Batch sizes to try, in order.
        title: Prefix for ``fit``'s outputs; batch size and rank are appended.
        device_id: CUDA device index (the rank in the 2-GPU run).

    Returns:
        dict mapping each completed batch size to the value returned by ``fit``.
    """
    if dataset is None:
        dataset = _load_tiny_imagenet_train()

    device = torch.device(f"cuda:{device_id}")
    # Make this GPU current so allocator calls and default CUDA allocations
    # target it rather than cuda:0.
    torch.cuda.set_device(device)
    distributed = dist.is_available() and dist.is_initialized()
    clear_cuda_memory()
    results = {}
    for batch_size in batch_sizes:
        net = _build_model(model_type).to(device)
        if distributed:
            # Without DDP each rank trains an independent copy and gradients are
            # never averaged, so multi-GPU timings would omit communication.
            net = DistributedDataParallel(net, device_ids=[device_id])
        train_loader, val_loader = _make_loaders(dataset, batch_size, num_workers, distributed)
        oom_break = False
        try:
            results[batch_size] = fit(
                net, train_loader, val_loader, epochs=epochs,
                break_after_num_batches=break_after_num_batches,
                # Unique per run and rank so exported files are not overwritten.
                title=f"{title}_bs{batch_size}_rank{device_id}",
            )
            print(f"Processed for batch size {batch_size}")
        except torch.cuda.OutOfMemoryError:
            # NOTE(review): under DDP an OOM on one rank can leave the other rank
            # blocked in a collective.
            print(f"Out of memory for batch size {batch_size}")
            oom_break = True
        del net, train_loader, val_loader
        clear_cuda_memory()
        time.sleep(10)  # pause between runs; presumably lets the GPU settle before the next measurement
        if oom_break:
            break
    return results

def setup(rank, world_size):
    """Join the default NCCL process group and bind this worker to GPU ``rank``.

    Args:
        rank: This process's rank; also used as its CUDA device index.
        world_size: Total number of processes in the group.
    """
    # Single-node rendezvous; setdefault lets an outer launcher override these.
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '12355')

    # Select the GPU before any CUDA work so this rank does not also create a context on cuda:0.
    torch.cuda.set_device(rank)
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    """Destroy the default process group if one is active; safe to call repeatedly."""
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()

def train(rank, world_size):
    """Per-process entry point for ``mp.spawn``: profile checkpointed ResNet-18 training.

    Args:
        rank: Process index supplied by ``mp.spawn``; also the CUDA device index.
        world_size: Total number of processes in the group.
    """
    setup(rank, world_size)
    try:
        clear_cuda_memory()
        # Build the dataset in the worker. No module-level ``tiny_imagenet_torch``
        # exists, and a spawned interpreter would not inherit one anyway.
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ])
        dataset = TinyImageNet(load_dataset("Maysee/tiny-imagenet", split="train"), transform=transform)
        fit_helper(
            model_type="resnet18_with_checkpointing",
            dataset=dataset,
            epochs=1,
            break_after_num_batches=10,
            num_workers=2,
            batch_sizes=[2500],
            title=f"with_checkpointing_{world_size}_gpus",
            device_id=rank,
        )
    finally:
        # Always leave the process group, even if training raised.
        cleanup()



## 6. Run Experiment

In [12]:
# Free cached GPU memory in this notebook process and show per-GPU usage before launching workers.
clear_cuda_memory()
!nvidia-smi

Allocated memory: 0.00 MB
Cached memory: 0.00 MB
Thu Feb 13 01:39:17 2025       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.4     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Quadro RTX 5000                Off | 00000000:1E:00.0 Off |                  Off |
| 34%   31C    P8              23W / 230W |    124MiB / 16384MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
|   1  Quadro RTX 5

In [13]:
# NOTE: mp.spawn starts fresh interpreters that unpickle `train` by looking it up in
# `__main__`. Functions defined in notebook cells are not found there (hence the
# AttributeError below). Move `train` and its helpers into an importable module
# (e.g. `%%writefile ddp_train.py` + `from ddp_train import train`) before running this cell.
WORLD_SIZE = 2
# `args` must be a tuple: each worker is called as train(rank, *args).
mp.spawn(train, args=(WORLD_SIZE,), nprocs=WORLD_SIZE, join=True)

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda3/envs/py3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/root/miniconda3/envs/py3.10/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/root/miniconda3/envs/py3.10/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'train' on <module '__main__' (built-in)>
W0213 01:41:06.032000 140074185762624 torch/multiprocessing/spawn.py:146] Terminating process 3722 via signal SIGTERM


ProcessExitedException: process 1 terminated with exit code 1

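## 7. Conclusions
- Checkpointing can help fit larger batch sizes within limited GPU memory, at the cost of recomputing activations during the backward pass.
- This notebook sets out to measure how checkpointing affects training time and memory usage.
- The 2-GPU run in Section 6 failed before any training step (the spawned workers could not find `train`), so no timing or memory results are reported yet.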