In [34]:
# Change the kernel's working directory to two levels above the current one.
# NOTE(review): the path is relative, so re-running this cell without a kernel
# restart moves up another two levels — presumably it should land on the
# project root so that `src` is importable; confirm.
%cd ../..

/media/atem/Data/HSE_videos/4_DLA


In [35]:
import torch
from functools import partial
from omegaconf import OmegaConf
from hydra.utils import instantiate
from hydra import initialize, compose

In [36]:
# --- Hardware budget -------------------------------------------------------
GPU_MEMORY_GB = 24   # L4 GPU = 24 GB

# --- Audio / STFT settings (they determine the size of the dummy input) ----
SR = 16000           # sample rate, samples per second
DURATION = 2.2       # clip length in seconds
HOP_LENGTH = 1024    # hop between consecutive frames, in samples
N_FFT = 6144         # FFT size; gives N_FFT // 2 + 1 frequency bins

In [None]:
from src.model import DTTNetModel

# Baseline model: built without passing `use_checkpoints`, so whatever default
# DTTNetModel defines for it applies here.
model = DTTNetModel(
    n_heads=2,
    hop_length=HOP_LENGTH,
    n_fft=N_FFT,
    fc_dim=N_FFT // 2 + 1,  # same value as the frequency-bin count of the dummy input
)
# Put the model in eval mode; as the last expression of the cell this also
# displays the module tree.
model.eval()

DTTNetModel(
  (encoder): Encoder(
    (init_conv): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
    (encoder_layers): ModuleList(
      (0): EncoderBlock(
        (tfc_tdf): TFC_TDF_Block(
          (conv1): Sequential(
            (0): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (1): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (2): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
           

In [38]:
# Count only the parameters that will actually be optimised.
num_params = sum(
    param.numel()
    for param in filter(lambda t: t.requires_grad, model.parameters())
)

# 4 bytes per parameter (fp32 weights).
params_bytes_fp32 = 4 * num_params
# Adam keeps two extra buffers per parameter (m and v).
optimizer_overhead = 2 * params_bytes_fp32
# NOTE(review): gradient buffers (one more parameter-sized copy during
# training) are not part of this figure — confirm whether that is intended.
static_bytes = optimizer_overhead + params_bytes_fp32

print(
    "ðŸ”¹ Model summary:",
    f"  â€¢ Trainable params: {num_params:,}",
    f"  â€¢ Model params:     {params_bytes_fp32 / 1024**2:.2f} MB",
    f"  â€¢ Static (model + Adam): {static_bytes / 1024**2:.2f} MB\n",
    sep="\n",
)

ðŸ”¹ Model summary:
  â€¢ Trainable params: 37,704,843
  â€¢ Model params:     143.83 MB
  â€¢ Static (model + Adam): 431.50 MB



In [39]:
# Batch of one: everything measured below is therefore "per sample".
B = 1
# Clip length in samples.
num_samples = int(SR * DURATION)
audio_len = num_samples
# Number of frequency bins for an N_FFT-point transform.
F = N_FFT // 2 + 1
# Number of whole N_FFT-sample windows that fit when stepping by HOP_LENGTH.
# NOTE(review): this is the frame count without edge padding — confirm it
# matches the STFT settings the real data pipeline uses.
T = 1 + (num_samples - N_FFT) // HOP_LENGTH

# Two independent random tensors of shape (B, F, T) stand in for real inputs;
# only their shape matters for the size measurement.
spectrogram, phase = (torch.randn(B, F, T) for _ in range(2))

print(
    "ðŸ”¹ Input shape info:",
    f"  â€¢ FFT bins (F):     {F}",
    f"  â€¢ Time frames (T):  {T}",
    f"  â€¢ Input tensor:     {tuple(spectrogram.shape)}\n",
    sep="\n",
)

ðŸ”¹ Input shape info:
  â€¢ FFT bins (F):     3073
  â€¢ Time frames (T):  29
  â€¢ Input tensor:     (1, 3073, 29)



In [40]:
def hook_fn(activation_list, module, input, output):
    """Forward hook that records how many bytes a module's output occupies.

    Meant to be bound with ``functools.partial`` so that ``activation_list``
    is fixed and the remaining ``(module, input, output)`` arguments match the
    forward-hook call signature.

    Args:
        activation_list: list that receives one integer (bytes) per call.
        module: the module that produced ``output`` (unused).
        input: the module's positional inputs (unused).
        output: the module's return value. Tensors are measured as
            ``numel() * element_size()``; lists, tuples, sets and dict values
            are traversed; anything else contributes 0 bytes.
    """
    total_bytes = 0
    pending = [output]  # explicit work stack instead of recursion
    while pending:
        item = pending.pop()
        if isinstance(item, torch.Tensor):
            total_bytes += item.numel() * item.element_size()
        elif isinstance(item, dict):
            pending.extend(item.values())
        elif isinstance(item, (list, tuple, set)):
            pending.extend(item)
        # any other type (None, str, numbers, ...) holds no tensor data
    activation_list.append(total_bytes)

# One entry (bytes) is appended here for every module call during a forward pass.
activation_sizes = []
# Attach the size-recording hook to every module returned by model.modules()
# and keep the handles so the hooks can be removed afterwards.
# NOTE(review): model.modules() yields the root module and nested containers
# as well as the leaf layers, so an output that a container merely passes
# through (e.g. the last layer of a Sequential) is recorded once per enclosing
# module. The resulting total is an overestimate — confirm this is intended.
hooks = list(
    map(
        lambda mod: mod.register_forward_hook(partial(hook_fn, activation_sizes)),
        model.modules(),
    )
)

In [41]:
# Run one forward pass so every registered hook records its module's output
# size. Gradients are not needed for a pure size measurement.
# The try/finally guarantees the hooks are detached even if the forward pass
# raises; otherwise they would stay on the model and fire again on every later
# call. The result is deliberately not bound to `_`, which IPython uses for the
# last cell output, so the output tensor is not kept alive either.
try:
    with torch.no_grad():
        model(spectrogram, phase, audio_len)
finally:
    for h in hooks:
        h.remove()

# Sum of the recorded output sizes, in bytes as actually produced by this
# forward pass (the dummy inputs come from torch.randn, float32 by default).
activations_bytes_fp32 = sum(activation_sizes)
# BF16 estimate: assumes every activation would take exactly half the bytes.
activations_bytes_bf16 = activations_bytes_fp32 / 2

print("ðŸ”¹ Activations per sample:")
print(f"  â€¢ FP32: {activations_bytes_fp32 / 1024**2:.2f} MB")
print(f"  â€¢ BF16: {activations_bytes_bf16 / 1024**2:.2f} MB\n")

ðŸ”¹ Activations per sample:
  â€¢ FP32: 1551.78 MB
  â€¢ BF16: 775.89 MB



In [42]:
# GPU memory budget in bytes (GPU_MEMORY_GB is interpreted as GiB: 2**30 bytes each).
GPU_BYTES = GPU_MEMORY_GB * 2**30

def compute_batch_limits(total_mem, static_mem, per_sample_mem, safety_factor=0.9):
    """Estimate how many samples fit into memory next to the static footprint.

    Args:
        total_mem: total memory budget, in bytes.
        static_mem: memory that does not scale with batch size, in bytes.
        per_sample_mem: memory needed per sample, in bytes (int or float).
        safety_factor: fraction of the theoretical maximum batch size to
            recommend; must lie in [0, 1]. Defaults to 0.9 (10% margin).
            Note that the margin is applied to the batch size, not to the
            memory budget.

    Returns:
        Tuple ``(b_max, b_safe)`` of ints: the theoretical maximum batch size
        and the recommended one (``b_max * safety_factor`` rounded down).
        Both are 0 when ``static_mem`` alone already fills ``total_mem``.

    Raises:
        ValueError: if ``per_sample_mem`` is not positive or ``safety_factor``
            is outside [0, 1].
    """
    if per_sample_mem <= 0:
        raise ValueError(f"per_sample_mem must be positive, got {per_sample_mem!r}")
    if not 0 <= safety_factor <= 1:
        raise ValueError(f"safety_factor must be in [0, 1], got {safety_factor!r}")

    available = total_mem - static_mem
    if available <= 0:
        # Nothing left for activations: report 0 rather than a negative size.
        return 0, 0

    # Floor division may yield a float when per_sample_mem is a float.
    b_max = int(available // per_sample_mem)
    b_safe = int(b_max * safety_factor)
    return b_max, b_safe

# Same GPU budget and static footprint for both precisions; only the
# per-sample activation size differs (FP32 first, then BF16).
(b_max_fp32, b_safe_fp32), (b_max_bf16, b_safe_bf16) = (
    compute_batch_limits(GPU_BYTES, static_bytes, per_sample_bytes)
    for per_sample_bytes in (activations_bytes_fp32, activations_bytes_bf16)
)

In [43]:
# Final report: static footprint, per-sample activation cost and the resulting
# batch-size limits. Built as one list of lines and printed in a single call;
# entries ending in "\n" produce the blank separator lines.
# NOTE(review): the "safe" figure is 90% of the maximum batch size, which is
# not exactly the same thing as 90% of VRAM as the label suggests.
print("\n".join([
    "============================================================",
    f"ðŸ“‹ MEMORY SUMMARY (GPU = {GPU_MEMORY_GB} GB)\n",
    f"ðŸ”¸ Static model + optimizer: {static_bytes / 1024**2:.2f} MB",
    "ðŸ”¸ Activations per sample:",
    f"    â€¢ FP32: {activations_bytes_fp32 / 1024**2:.2f} MB",
    f"    â€¢ BF16: {activations_bytes_bf16 / 1024**2:.2f} MB\n",
    "ðŸ”¹ MAX batch size (theoretical):",
    f"    â€¢ FP32: {b_max_fp32}",
    f"    â€¢ BF16: {b_max_bf16}\n",
    "ðŸ”¹ Recommended safe batch size (~90% VRAM):",
    f"    â€¢ FP32: {b_safe_fp32}",
    f"    â€¢ BF16: {b_safe_bf16}",
    "============================================================",
]))

ðŸ“‹ MEMORY SUMMARY (GPU = 24 GB)

ðŸ”¸ Static model + optimizer: 431.50 MB
ðŸ”¸ Activations per sample:
    â€¢ FP32: 1551.78 MB
    â€¢ BF16: 775.89 MB

ðŸ”¹ MAX batch size (theoretical):
    â€¢ FP32: 15
    â€¢ BF16: 31

ðŸ”¹ Recommended safe batch size (~90% VRAM):
    â€¢ FP32: 13
    â€¢ BF16: 27


## Calculate with checkpointing

In [None]:
from src.model import DTTNetModel

# Rebuild the model with the same hyper-parameters as the baseline, this time
# passing use_checkpoints=True, and rebind `model` to it.
model = DTTNetModel(
    use_checkpoints=True,
    n_heads=2,
    hop_length=HOP_LENGTH,
    n_fft=N_FFT,
    fc_dim=N_FFT // 2 + 1,  # same value as the frequency-bin count of the dummy input
)
# Eval mode; as the last expression this also displays the module tree.
model.eval()