In [1]:
!pip install datasets
!pip install matplotlib
import torch
import torch.nn as nn
from torchvision import transforms, models
from torch.profiler import record_function
from datasets import load_dataset
import time
import numpy as np
import matplotlib.pyplot as plt

# Seed the torch and numpy global random number generators so that runs are repeatable.
# NOTE(review): the original comment also mentions the GPU, but no CUDA-specific
# seeding call is made here - confirm torch.manual_seed covers the CUDA generators
# on the installed PyTorch version.

torch.manual_seed(710)
np.random.seed(710)

[0m

In [2]:
# Convert to torch dataset

class TinyImageNet(torch.utils.data.Dataset):
    """Expose an indexable dataset of ``{"image", "label"}`` rows as a torch ``Dataset``.

    Args:
        dataset: Indexable object with ``__len__`` whose items are mappings that
            contain an ``"image"`` entry (an object with a ``convert(mode)``
            method) and a ``"label"`` entry (an integer).
        transform: Optional callable applied to the RGB-converted image.
    """

    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is converted to RGB and passed through ``transform`` when one
        was given; the label is returned as a 0-dim ``torch.int64`` tensor.
        """
        # Index the wrapped dataset exactly once: every lookup re-materialises
        # the whole row, so reading "image" and "label" separately doubles the
        # per-sample cost.
        row = self.dataset[idx]
        x, y = row["image"], row["label"]
        # convert x to RGB
        x = x.convert("RGB")
        if self.transform is not None:
            x = self.transform(x)
        y = torch.tensor(y, dtype=torch.int64)
        return x, y

# Define transformations
# Every image is resized to 224x224 and converted to a tensor before it reaches the model.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the "train" split of the Hugging Face dataset and wrap it in the torch Dataset above.
tiny_imagenet = load_dataset("Maysee/tiny-imagenet", split="train")
tiny_imagenet_torch = TinyImageNet(tiny_imagenet, transform=transform)
# Number of distinct label names declared by the dataset's "label" feature.
num_classes = len(tiny_imagenet.features["label"].names)
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


# Display the first raw sample and the selected device as the cell output.
tiny_imagenet[0], device

({'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=64x64>,
  'label': 0},
 device(type='cuda', index=0))

In [3]:
# NOTE(review): this calls a private torch.cuda.memory API (leading underscore);
# presumably it starts recording allocator history, capped at 10000 entries.
# No matching snapshot dump or stop call is visible in this notebook - confirm
# whether the recording is actually consumed anywhere.
torch.cuda.memory._record_memory_history(
       max_entries=10000
   )

# Measure the GPU memory attributable to the model itself: read the allocated
# byte count before and after moving the model to the device and take the difference.
model_gpu_usage_before = torch.cuda.memory_allocated(device)
model = models.resnet18(pretrained=True)
model.to(device)

model_gpu_usage_after = torch.cuda.memory_allocated(device)

model_gpu_usage = model_gpu_usage_after - model_gpu_usage_before

# print number of parameters in the model

print(f"Number of parameters in the model: {sum(p.numel() for p in model.parameters())}")
print(f"Model GPU usage: {model_gpu_usage / 1024**2:.2f} MB")

# Drop the only reference to the probe model; later cells build their own models.
del model




Number of parameters in the model: 11689512
Model GPU usage: 44.69 MB


### Compare effect of checkpointing on memory usage for large batch sizes

In [4]:
def fit(model, train_loader, val_loader, epochs=1, lr=0.001, break_after_num_batches=None, title="", rolling_window=1):
    """Train ``model`` under the PyTorch profiler and collect per-batch wall times.

    The loop runs Adam with cross-entropy loss inside a ``torch.profiler.profile``
    context (CPU + CUDA activities, memory profiling enabled, schedule
    ``wait=0, warmup=0, active=6, repeat=1``). After the loop the profiler's
    memory timeline is exported to ``f"{title}_memory.html"`` for ``cuda:0``.

    Args:
        model: Module to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        val_loader: Accepted for interface compatibility; not used by this function.
        epochs: Number of passes over ``train_loader``.
        lr: Learning rate passed to Adam.
        break_after_num_batches: When not ``None``, stop each epoch's inner loop
            once this many batches have been processed.
        title: Prefix of the exported memory-timeline file name.
        rolling_window: Width of the moving average applied to the per-batch
            times. ``1`` (default) leaves the raw times unchanged. The window is
            clamped to the number of recorded batches.

    Returns:
        dict with ``"total_times"`` (numpy array of smoothed per-batch times in
        seconds) and ``"mean_time"`` (their mean, or ``nan`` when no batch ran).

    Note:
        Batches are moved to the module-level ``device``. Wall times are taken
        with ``time.time()`` without an explicit CUDA synchronisation, so they
        may not include GPU work that is still queued - TODO confirm whether a
        synchronize is wanted before reading the clock.
    """
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()
    total_times = []

    with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(wait=0, warmup=0, active=6, repeat=1),
        record_shapes=True,
        with_stack=True,
        profile_memory=True
    ) as prof:

        for epoch in range(epochs):
            # The clock starts before the loader yields, so each sample covers
            # data loading as well as the training step.
            start_time = time.time()
            for batch_idx, batch in enumerate(train_loader):
                inputs, labels = batch

                with record_function("to_device"):
                    inputs, labels = inputs.to(device), labels.to(device)

                with record_function("forward"):
                    outputs = model(inputs)

                with record_function("backward"):
                    criterion(outputs, labels).backward()

                with record_function("optimizer_step"):
                    optimizer.step()
                    optimizer.zero_grad()

                end_time = time.time()
                # Record the batch time; without this the array below is always empty.
                total_times.append(end_time - start_time)

                # Signal the end of this iteration to the profiler schedule, so
                # every profiler step spans one complete batch.
                prof.step()

                # batch_idx is zero-based: stop after exactly N processed batches.
                if break_after_num_batches is not None and batch_idx + 1 >= break_after_num_batches:
                    break
                start_time = time.time()

    total_times = np.asarray(total_times, dtype=float)
    if total_times.size == 0:
        # Nothing was processed (empty loader or epochs == 0).
        mean_time = float("nan")
    else:
        # Clamp the window so the 'valid' convolution always has at least one output.
        window = max(1, min(int(rolling_window), total_times.size))
        total_times = np.convolve(total_times, np.ones(window) / window, mode='valid')
        mean_time = total_times.mean()

    prof.export_memory_timeline(f"{title}_memory.html", device="cuda:0")
    return {
        "mean_time": mean_time,
        "total_times": total_times
    }

    

def clear_cuda_memory():
    """Release unused GPU memory held by PyTorch and print the resulting usage.

    Runs Python garbage collection first so that unreachable tensors are
    actually freed, then waits for pending GPU work, empties the CUDA caching
    allocator, resets the peak-memory statistics and prints the allocated and
    reserved byte counts in MB.

    When CUDA is not available only the garbage collection is performed and a
    short notice is printed, so the helper is safe to call on CPU-only machines.
    """
    import gc

    # Collect first: blocks owned by still-referenced (cyclic) garbage cannot
    # be returned by empty_cache().
    gc.collect()

    if not torch.cuda.is_available():
        print("CUDA is not available; no GPU memory to clear.")
        return

    # Wait for queued kernels so their buffers are no longer in use.
    torch.cuda.synchronize()

    # Clear memory caches
    torch.cuda.empty_cache()

    # Reset peak memory stats
    torch.cuda.reset_peak_memory_stats()

    # Print memory stats to verify
    print(f"Allocated memory: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
    print(f"Cached memory: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")

def fit_helper(model_type, dataset, epochs, break_after_num_batches, title):
    """Run ``fit`` once per configured batch size and plot the per-batch times.

    Reads the module-level ``batch_sizes``, ``num_workers`` and ``device``.
    For each batch size a fresh model is built, trained with ``fit`` and the
    returned times are plotted. The sweep stops at the first batch size that
    raises a CUDA out-of-memory error.

    Args:
        model_type: ``"resnet18_without_checkpointing"`` or
            ``"resnet18_with_checkpointing"``.
        dataset: Dataset used for both the training and validation loaders.
        epochs: Number of epochs forwarded to ``fit``.
        break_after_num_batches: Forwarded to ``fit``.
        title: Base name for the exported profile; ``_{batch_size}`` is
            appended separately for every run.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    supported = ("resnet18_without_checkpointing", "resnet18_with_checkpointing")
    if model_type not in supported:
        raise ValueError(f"Unknown model_type {model_type!r}; expected one of {supported}")

    clear_cuda_memory()
    for batch_size in batch_sizes:
        if model_type == "resnet18_without_checkpointing":
            model = models.resnet18(pretrained=True)
        else:
            model = ResnetCheckpointed()
        model.to(device)
        train_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        val_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        # Build the per-run title from the base title so suffixes do not pile
        # up across batch sizes.
        run_title = f"{title}_{batch_size}"
        oom_break = False

        try:
            times_dict = fit(
                model,
                train_loader,
                val_loader,
                epochs=epochs,
                break_after_num_batches=break_after_num_batches,
                title=run_title,
            )
            print(f"Proccessed for batch size {batch_size}")

            batch_times = times_dict['total_times']
            batch_ids = np.arange(len(batch_times))

            fig, ax = plt.subplots(figsize=(10, 5))
            ax.plot(batch_ids, batch_times, label=f"load time avg {times_dict['mean_time']}", marker="o", alpha=0.5)
            ax.set_xlabel("Batch ID")
            ax.set_ylabel("Time (s)")
            ax.set_title(f"load times with {num_workers} workers with avg total time {times_dict['mean_time']} and {model_type}")
            ax.legend()
            plt.show()
            # Close the figure so repeated runs do not accumulate open figures.
            plt.close(fig)

        except torch.cuda.OutOfMemoryError:
            print(f"Out of memory for batch size {batch_size}")
            oom_break = True
        finally:
            # clear memory, also when an unexpected error is propagating
            del model
            del train_loader
            del val_loader
            clear_cuda_memory()

        time.sleep(10)
        if oom_break:
            break

### Break down memory usage in Resnet
- GPU memory is needed to store the model itself and the activations that are kept for the backward pass

In [5]:
from torch.utils.checkpoint import checkpoint_sequential
import torch.nn as nn
import torchvision.models as models

class ResnetCheckpointed(nn.Module):
    """Pretrained ResNet-18 whose feature extractor runs under activation checkpointing.

    The convolutional stem, the four residual stages and the average pool are
    placed in one ``nn.Sequential`` that ``checkpoint_sequential`` splits into
    ``segments`` chunks; the fully connected head runs normally afterwards.

    Args:
        segments: Number of chunks the feature stack is split into for
            checkpointing. Must be between 1 and the number of feature modules.

    Raises:
        ValueError: If ``segments`` is outside the valid range.
    """

    def __init__(self, segments=3):
        super(ResnetCheckpointed, self).__init__()
        self.model = models.resnet18(pretrained=True)

        # Create a sequential container for the features
        self.features = nn.Sequential(
            self.model.conv1,
            self.model.bn1,
            self.model.relu,
            self.model.maxpool,
            self.model.layer1,
            self.model.layer2,
            self.model.layer3,
            self.model.layer4,
            self.model.avgpool
        )
        self.fc = self.model.fc

        if not 1 <= segments <= len(self.features):
            raise ValueError(
                f"segments must be between 1 and {len(self.features)}, got {segments}"
            )
        # Number of segments to split the features into for checkpointing
        self.segments = segments

    def forward(self, x):
        """Return the classification logits for the image batch ``x``."""
        # Apply checkpoint_sequential to features.
        # use_reentrant=False is required here: the input batch does not require
        # grad, and the reentrant variant produces no gradients for a
        # checkpointed segment unless one of its inputs requires grad.
        x = checkpoint_sequential(self.features, self.segments, x, use_reentrant=False)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x

In [6]:
# Run configuration. num_workers and batch_sizes are read as module-level
# globals by fit_helper rather than being passed in as arguments.
num_workers = 4
break_after_num_batches = 10
# NOTE(review): batch_indices is not referenced anywhere in the visible notebook.
batch_indices = None
batch_sizes = [2048]


# Sweep the configured batch sizes with the checkpointed ResNet-18 for one epoch.
fit_helper("resnet18_with_checkpointing", tiny_imagenet_torch, 1, break_after_num_batches, title="with_checkpointing")



Allocated memory: 0.00 MB
Cached memory: 0.00 MB


  warn("Profiler won't be using warmup, this can skew profiler results")


Out of memory for batch size 2048
Allocated memory: 16.25 MB
Cached memory: 6272.00 MB


## Concluding remarks
- Without gradient checkpointing, we were able to fit batch sizes of up to 1024 on a single GPU
- With gradient checkpointing, we were able to fit batch sizes of up to 4096 on a single GPU