In [None]:
import math
import typing

import matplotlib.pyplot as plt
import numpy as np
from sklearn.linear_model import LinearRegression
import torch
import torchvision.transforms.v2 as transforms
from torch import nn, optim
from torchvision import datasets, models

import cifar_resnet

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
# (To target the last GPU instead, use f'cuda:{torch.cuda.device_count() - 1}'.)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
torch.jit.enable_onednn_fusion(True)
if device.type == 'cuda':
    torch.cuda.set_device(device)
    # Allow TF32 matmuls/convolutions and let cuDNN benchmark its algorithms.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    capability = torch.cuda.get_device_capability(device)
    # Global (free, total) GPU memory in bytes.
    mem_info = torch.cuda.mem_get_info(device=device)
else:
    # The CUDA-only queries raise on a CPU device, so they are skipped there.
    capability = None
    mem_info = None
print(f'Device: {device}, Type: {device.type}, Compute_Capability: {capability}')

In [None]:
if device.type == 'cuda':
    # Snapshot of the GPU's name and its (free, total) memory in bytes.
    # Later cells use the free figure as the memory budget for the prediction.
    GPU_info = {
        'device_name': torch.cuda.get_device_name(device=device),
        'mem_info': torch.cuda.mem_get_info(device=device),
    }
    # Outer double quotes keep this f-string valid on Python < 3.12 as well.
    print(f"GPU Name: {GPU_info['device_name']}, Memory (free, total): {GPU_info['mem_info']}")

In [None]:
# Dataset cache directory; kept identical to the path used when the datasets are built.
root = '~/.pytorch/datasets'
# NOTE(review): these are the commonly used ImageNet channel statistics, not
# CIFAR-100 specific ones -- confirm this is intended.
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
# Stage applied before the augmentations: cast to uint8 (with value scaling).
transform_pre = nn.Sequential(
    transforms.ToDtype(torch.uint8, scale=True),
    )
# Stage applied after the augmentations: to image tensor, float32 (scaled), normalised.
transform_post = nn.Sequential(
    transforms.ToImage(),
    transforms.ToDtype(dtype=torch.float32, scale=True),
    transforms.Normalize(mean=mean, std=std),
)
# Per-split pipelines: only the training split gets random crop/flip augmentation.
transform = {
    'train': nn.Sequential(
        transform_pre,
        transforms.RandomResizedCrop(size=32, scale=(.8, 1), ratio=(.8, 1.2)),
        transforms.RandomHorizontalFlip(),
        transform_post,
    ),
    'eval': nn.Sequential(
        transform_pre,
        transform_post,
    ),
}

In [None]:
# Where the CIFAR-100 files are stored / downloaded to.
root = '~/.pytorch/datasets'
# One CIFAR-100 dataset per split, each wired to the matching transform pipeline.
dataset = {
    split: datasets.CIFAR100(
        root=root,
        train=(split == 'train'),
        transform=transform[split],
        download=True,
    )
    for split in ('train', 'eval')
}
# Build the network, then move it to the target device in channels-last layout.
model = cifar_resnet.CIFAR_ResNet(n=3, num_classes=100, p=0.2)
model = model.to(device, memory_format=torch.channels_last)

In [None]:
def train_a_batch(model, dataset, dataloader, criterion, optimizer):
    """Run one optimisation step on the first batch yielded by ``dataloader``.

    The model is put into training mode, the first batch is moved to the
    module-level ``device`` (inputs in channels-last layout), and a single
    forward/backward/optimizer step is performed. Any further batches are
    never fetched.

    Args:
        model: Network to train; updated in place by ``optimizer``.
        dataset: Unused; kept so existing callers keep working.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer that updates ``model``'s parameters.

    Returns:
        ``(record_loss, record_acc)``. Loss/accuracy bookkeeping is currently
        disabled, so both values are always ``0``.
    """
    record_loss, record_acc = 0, 0
    model.train()

    # Only one batch is needed; ``None`` signals an empty dataloader.
    batch = next(iter(dataloader), None)
    if batch is None:
        return record_loss, record_acc

    # load data
    inputs = batch[0].to(device, non_blocking=True, memory_format=torch.channels_last)
    labels = batch[1].to(device, non_blocking=True)

    # compute (plain full-precision step; no autocast / gradient scaling)
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()

    return record_loss, record_acc

In [None]:
def timed_mem(fn):
    """Call ``fn()`` once and report its result, elapsed time and peak GPU memory.

    The peak-memory counter of the module-level ``device`` is reset before the
    call, and a pair of CUDA events brackets the call for timing.

    Args:
        fn: Zero-argument callable to execute.

    Returns:
        ``(result, seconds, peak)`` where ``result`` is ``fn()``'s return
        value, ``seconds`` is the time between the two events (the event API
        reports milliseconds, hence the division by 1000) and ``peak`` is
        ``torch.cuda.max_memory_allocated`` for ``device`` after the call.
    """
    torch.cuda.reset_peak_memory_stats(device=device)
    tic = torch.cuda.Event(enable_timing=True)
    toc = torch.cuda.Event(enable_timing=True)

    tic.record()
    result = fn()
    toc.record()
    # Wait for the GPU so that both events have completed before reading them.
    torch.cuda.synchronize()

    elapsed_seconds = tic.elapsed_time(toc) / 1000
    peak_allocated = torch.cuda.max_memory_allocated(device=device)
    return result, elapsed_seconds, peak_allocated

### Get Mem Usage, BS = [64, 128, ..., 512] (step 64)

In [None]:
# Batch sizes to profile: every multiple of 64 from 64 up to and including 512.
batch_size_ls = [64 * k for k in range(1, 512 // 64 + 1)]
# Peak memory measured for each batch size, filled in by the profiling loop.
mem_cost_ls = list()

In [None]:
# Reset here as well, so that re-running this cell cannot leave stale or
# duplicated measurements in mem_cost_ls (it must stay aligned with batch_size_ls).
mem_cost_ls = []
for batch_size in batch_size_ls:

    ## configuration
    dataloader = {
        'train': torch.utils.data.DataLoader(
            dataset['train'],
            batch_size=batch_size,
            shuffle=True,
        ),
        'eval': torch.utils.data.DataLoader(
            dataset['eval'],
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        ),
    }

    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
    # Learning rate is scaled linearly with the batch size (reference: 0.1 at 128).
    optimizer = optim.SGD(model.parameters(), lr=1e-1 * (batch_size / 128), momentum=0.9, weight_decay=1e-4)
    #optimizer = optim.AdamW(model.parameters(), lr=1e-3 * (batch_size / 1024), weight_decay=0.05)

    ## training
    # One training step, wrapped so that its time and peak GPU memory are recorded.
    result, time_cost, mem_cost = timed_mem(
        lambda: train_a_batch(model, dataset['train'], dataloader['train'], criterion, optimizer)
    )

    ## result
    mem_cost_ls.append(mem_cost)
    print(batch_size, mem_cost_ls[-1])

# The regression fitted afterwards needs exactly one measurement per batch size.
assert len(mem_cost_ls) == len(batch_size_ls)

### Predict Mem, Linear Regression

In [None]:
# Fit measured peak memory as a linear function of batch size
# (one feature column: the batch size; target: the recorded memory cost).
reg_model = LinearRegression()
reg_model.fit(
    np.asarray(batch_size_ls)[:, None],
    np.asarray(mem_cost_ls),
)

In [None]:
# Fitted line: intercept_ is the predicted memory at batch size 0,
# coef_ holds the slope (additional memory per extra sample in the batch).
print(reg_model.intercept_, reg_model.coef_)

### Get Predicted Max Mem

In [None]:
# Largest candidate batch size considered when searching for the predicted
# maximum; the search below evaluates batch sizes 1..guess_max.
guess_max = 15000

In [None]:
# Candidate batch sizes 1, 2, ..., guess_max and the memory the fitted
# regression predicts for each of them.
predict_bs_ls = np.arange(guess_max) + 1
predict_mem_ls = reg_model.predict(predict_bs_ls[:, None])

In [None]:
# Free GPU memory recorded in GPU_info; used as the budget.
free_mem = GPU_info['mem_info'][0]
# Candidates whose predicted memory stays within the budget.
fits = predict_mem_ls <= free_mem
if not fits.any():
    raise RuntimeError(
        f'No batch size in [{predict_bs_ls[0]}, {predict_bs_ls[-1]}] is predicted '
        f'to fit into {free_mem} bytes of free GPU memory'
    )
# Index (into the full candidate arrays) of the fitting candidate with the
# largest predicted memory. Going through predict_bs_ls instead of "index + 1"
# stays correct even if the fitting candidates are not a prefix starting at 1.
# NOTE: if this is the last candidate, the result is capped by guess_max.
best = np.flatnonzero(fits)[predict_mem_ls[fits].argmax()]
predict_bs = predict_bs_ls[best]
predict_mem = predict_mem_ls[best]

In [None]:
# Report the memory budget next to the predicted batch size / memory.
# Outer double quotes keep the first f-string valid on Python < 3.12 as well.
print(f"GPU mem: {GPU_info['mem_info'][0]}")
print(f'predict max bs: {predict_bs}')
print(f'predict max mem: {predict_mem}')

### Get Measured Max Mem

In [None]:
# Re-run the single-step measurement at the batch size the regression predicted.
batch_size = int(predict_bs)

## configuration
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1 * (batch_size / 128),
    momentum=0.9,
    weight_decay=1e-4,
)
#optimizer = optim.AdamW(model.parameters(), lr=1e-3 * (batch_size / 1024), weight_decay=0.05)

# One loader per split; only the training split is shuffled.
dataloader = {
    split: torch.utils.data.DataLoader(
        dataset[split],
        batch_size=batch_size,
        shuffle=(split == 'train'),
        drop_last=False,
    )
    for split in ('train', 'eval')
}


## training
def _train_once():
    """Run one training step on the training split."""
    return train_a_batch(model, dataset['train'], dataloader['train'], criterion, optimizer)


result, time_cost, mem_cost = timed_mem(_train_once)

## result
print(batch_size, mem_cost)

In [None]:
# Keep the measured peak memory under its own name so it can be compared
# with predict_mem in the summary.
measure_mem = mem_cost

### Summary

In [None]:
# Recap: device, predicted batch size, and predicted vs. measured peak memory
# (raw bytes plus whole MiB via floor division by 2**20).
# Outer double quotes keep the first f-string valid on Python < 3.12 as well.
print(f"GPU Name: {GPU_info['device_name']}, Memory (free, total): {GPU_info['mem_info']}")
print(f'Predict batch size: {predict_bs}')
print(f'Predict memory usage: {predict_mem} Byte == {predict_mem // 2**20: g} MiB')
print(f'Measure memory usage: {measure_mem} Byte == {measure_mem // 2**20: g} MiB')
print(f'Gap of memory usage: {predict_mem - measure_mem} Byte == {(predict_mem - measure_mem) // 2**20: g} MiB')