In [None]:
import math
import typing

import numpy as np
import torch
import torchvision.transforms.v2 as transforms
from torch import nn, optim
from torchvision import datasets, models

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device(
    #f'cuda:{torch.cuda.device_count() - 1}' if torch.cuda.is_available() else 'cpu'
    'cuda:0' if torch.cuda.is_available() else 'cpu'
)
torch.jit.enable_onednn_fusion(True)
if device.type == 'cuda':
    torch.cuda.set_device(device)
    capability = torch.cuda.get_device_capability(device)
    # Allow TF32 math and let cuDNN autotune its convolution algorithms.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    mem_info = torch.cuda.mem_get_info(device=device) # global (free, total) GPU memory
else:
    # CUDA-only queries would raise on a CPU host, so record "not available" instead.
    capability = None
    mem_info = None
print(f'Device: {device}, Type: {device.type}, Compute_Capability: {capability}')

In [None]:
# Snapshot the GPU's name and its (free, total) memory; the free value recorded
# here is what the later batch-size prediction cells read back.
if device.type == 'cuda':
    GPU_info = {
        'device_name': torch.cuda.get_device_name(device=device),
        'mem_info': torch.cuda.mem_get_info(device=device),
    }
    # Outer double quotes keep this f-string valid on Python versions before 3.12,
    # which reject reusing the enclosing quote character inside a replacement field.
    print(f"GPU Name: {GPU_info['device_name']}, Memory (free, total): {GPU_info['mem_info']}")

In [None]:
root = '~/.pytorch/dataset'
# Per-channel normalisation statistics (the commonly used ImageNet values).
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Spatial side length of the crops fed to the network.
size = 224

# Stage applied before the geometric augmentations.
transform_pre = nn.Sequential(
    transforms.ToDtype(torch.uint8, scale=True),
)
# Stage applied last: convert to an image tensor, rescale to float32, normalise.
transform_post = nn.Sequential(
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
    transforms.Normalize(mean=mean, std=std),
)
transform = {
    'train': nn.Sequential(
        transform_pre,
        transforms.RandomResizedCrop(size=size, scale=(.8, 1), ratio=(.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transform_post,
    ),
    # Evaluation images must share one spatial size, otherwise the default collate
    # cannot stack them into a batch once batch_size > 1.
    'eval': nn.Sequential(
        transform_pre,
        transforms.Resize(size=256),
        transforms.CenterCrop(size=size),
        transform_post,
    ),
}

In [None]:
# Dataset root; the 'train/' and 'val/' sub-folders hold the two splits.
root = '~/ssd/imagenet/'
dataset = {
    split: datasets.ImageFolder(root=f'{root}{subdir}', transform=transform[split])
    for split, subdir in (('train', 'train/'), ('eval', 'val/'))
}

# Freshly initialised ResNet-18, moved to the target device in channels-last layout.
model = models.resnet18()
model = model.to(device, memory_format=torch.channels_last)

In [None]:
def train_a_batch(model, dataset, dataloader, criterion, optimizer):
    """Run one optimisation step on the first batch yielded by ``dataloader``.

    The model is switched to training mode, a single forward/backward/step
    cycle is executed in plain precision (no autocast, no gradient scaler),
    and the loop stops immediately afterwards.

    Args:
        model: Network to train; called on the input batch.
        dataset: Accepted for signature compatibility; not read by this function.
        dataloader: Iterable whose items expose inputs at index 0 and labels at index 1.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.

    Returns:
        ``(record_loss, record_acc)``. Metric accumulation is disabled, so both
        values are always ``0``.
    """
    loss_total, acc_total = 0, 0  # placeholders: nothing is accumulated below

    model.train()
    for batch in dataloader:
        # Move the batch to the module-level `device`.
        images = batch[0].to(device, non_blocking=True, memory_format=torch.channels_last)
        targets = batch[1].to(device, non_blocking=True)

        # One standard SGD-style update.
        optimizer.zero_grad()
        logits = model(images)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        # Only the first batch is ever processed.
        break

    return loss_total, acc_total

In [None]:
def timed_mem(fn):
    """Call ``fn()`` and report its result, CUDA-timed duration and peak memory.

    Peak-memory statistics for the module-level ``device`` are reset before the
    call, and two CUDA events bracket it for timing.

    Args:
        fn: Zero-argument callable to measure.

    Returns:
        ``(result, elapsed, peak)`` where ``result`` is whatever ``fn`` returned,
        ``elapsed`` is the event-to-event time divided by 1000 (seconds, given that
        ``elapsed_time`` reports milliseconds), and ``peak`` is
        ``torch.cuda.max_memory_allocated`` for ``device``.
    """
    torch.cuda.reset_peak_memory_stats(device=device)

    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    result = fn()
    toc.record()

    # Both events must have completed before their timestamps can be compared.
    torch.cuda.synchronize()

    elapsed = tic.elapsed_time(toc) / 1000
    peak = torch.cuda.max_memory_allocated(device=device)
    return result, elapsed, peak

## BS = 1

In [None]:
batch_size = 1

# Loaders with default worker / pinning settings; only the training split is shuffled.
dataloader = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'eval')
}

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Learning rate is scaled linearly with batch size against a reference of 128.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1 * (batch_size / 128),
    momentum=0.9,
    weight_decay=1e-4,
)
#optimizer = optim.AdamW(model.parameters(), lr=1e-3 * (batch_size / 1024), weight_decay=0.05)

In [None]:
# Measure one training step at the current batch size: returned values, time, peak memory.
result, time_cost, mem_cost = timed_mem(
    fn=lambda: train_a_batch(
        model=model,
        dataset=dataset['train'],
        dataloader=dataloader['train'],
        criterion=criterion,
        optimizer=optimizer,
    ),
)

In [None]:
# Keep the peak allocation reported by timed_mem for the batch-size-1 step.
mem_bs_1 = mem_cost
print(mem_bs_1)

## BS = 2

In [None]:
batch_size = 2

# Rebuild both loaders at the new batch size (defaults otherwise; train split shuffled).
dataloader = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'eval')
}

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Fresh optimizer; learning rate scaled linearly against a reference batch size of 128.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1 * (batch_size / 128),
    momentum=0.9,
    weight_decay=1e-4,
)
#optimizer = optim.AdamW(model.parameters(), lr=1e-3 * (batch_size / 1024), weight_decay=0.05)

In [None]:
# Repeat the single-step measurement with the batch-size-2 loader.
result, time_cost, mem_cost = timed_mem(
    fn=lambda: train_a_batch(
        model=model,
        dataset=dataset['train'],
        dataloader=dataloader['train'],
        criterion=criterion,
        optimizer=optimizer,
    ),
)

In [None]:
# Keep the peak allocation reported by timed_mem for the batch-size-2 step.
mem_bs_2 = mem_cost
print(mem_bs_2)

## Predict ${BS}_{MAX}$

In [None]:
# Difference in measured peak memory between batch sizes 2 and 1, used below
# as the estimated cost of one additional sample.
mem_cost_gap = mem_bs_2 - mem_bs_1
print(mem_cost_gap)

In [None]:
# Linear model: peak(bs) = mem_bs_1 + (bs - 1) * mem_cost_gap.
# Solve for the largest bs whose predicted peak still fits in the free memory
# recorded in GPU_info.
if mem_cost_gap <= 0:
    # A zero gap would divide by zero and a negative one would yield a meaningless size.
    raise ValueError(
        f'Cannot extrapolate batch size: per-sample memory gap is {mem_cost_gap} bytes (must be > 0)'
    )
# Batch size 1 has already run successfully, so never predict less than that.
batch_size_max = max(1, (GPU_info['mem_info'][0] - mem_bs_1) // mem_cost_gap + 1)
print(batch_size_max)

In [None]:
# Peak memory the linear model predicts for batch_size_max:
# the batch-size-1 peak plus one mem_cost_gap per additional sample.
mem_bs_max_predict = mem_bs_1 + (batch_size_max - 1) * mem_cost_gap
print(mem_bs_max_predict)

## BS = MAX

In [None]:
# Use the extrapolated maximum as the batch size for the final measurement.
batch_size = batch_size_max

# Rebuild both loaders at that size (defaults otherwise; train split shuffled).
dataloader = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
    )
    for split in ('train', 'eval')
}

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
# Fresh optimizer; learning rate scaled linearly against a reference batch size of 128.
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1 * (batch_size / 128),
    momentum=0.9,
    weight_decay=1e-4,
)
#optimizer = optim.AdamW(model.parameters(), lr=1e-3 * (batch_size / 1024), weight_decay=0.05)

In [None]:
# Measure a single training step at the predicted maximum batch size.
result, time_cost, mem_cost = timed_mem(
    fn=lambda: train_a_batch(
        model=model,
        dataset=dataset['train'],
        dataloader=dataloader['train'],
        criterion=criterion,
        optimizer=optimizer,
    ),
)

In [None]:
# Echo the batch size that the preceding measurement was run with.
print(batch_size_max)

In [None]:
# Keep the peak allocation actually measured at batch_size_max.
mem_bs_max_measure = mem_cost
print(mem_bs_max_measure)

In [None]:
# Headroom between the free memory recorded in GPU_info and the measured peak, in MiB.
# Outer double quotes keep the f-string valid on Python versions before 3.12.
print(f"GAP: {(GPU_info['mem_info'][0] - mem_bs_max_measure) / 2**20} MiB")

## Summary

In [None]:
# Final report: device, predicted batch size, and predicted vs. measured peak memory.
# The first line uses outer double quotes so the f-string also parses before Python 3.12.
print(f"GPU Name: {GPU_info['device_name']}, Memory (free, total): {GPU_info['mem_info']}")
print(f'Predict batch size: {batch_size_max}')
# MiB figures are floor-divided, so they are rounded down to whole MiB.
print(f'Predict memory usage: {mem_bs_max_predict} Byte == {mem_bs_max_predict // 2**20: g} MiB')
print(f'Measure memory usage: {mem_bs_max_measure} Byte == {mem_bs_max_measure // 2**20: g} MiB')
print(f'Gap of memory usage: {mem_bs_max_predict - mem_bs_max_measure} Byte == {(mem_bs_max_predict - mem_bs_max_measure) // 2**20: g} MiB')