In [2]:
%cd ../..

/media/atem/Data/HSE_videos/4_DLA/hw_2_SeppechSep/git_speech_separation


In [None]:
import torch
from functools import partial

In [5]:
GPU_MEMORY_GB = 24  # L4 GPU = 24 GB
SR = 16_000         # sample rate
DURATION = 2.2      # seconds
N_FFT = 6144
HOP_LENGTH = 1024

In [None]:
from src.model import DTTNetModel

model = DTTNetModel( 
    fc_dim=N_FFT // 2 + 1,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_heads=2,
)
model.eval()

DTTNetModel(
  (encoder): Encoder(
    (init_conv): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
    (encoder_layers): ModuleList(
      (0): EncoderBlock(
        (tfc_tdf): TFC_TDF_Block(
          (conv1): Sequential(
            (0): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (1): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (2): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
           

In [11]:
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
params_bytes_fp32 = num_params * 4
optimizer_overhead = params_bytes_fp32 * 2  # Adam m + v
static_bytes = params_bytes_fp32 + optimizer_overhead

print("üîπ Model summary:")
print(f"  ‚Ä¢ Trainable params: {num_params:,}")
print(f"  ‚Ä¢ Model params:     {params_bytes_fp32 / 1024**2:.2f} MB")
print(f"  ‚Ä¢ Static (model + Adam): {static_bytes / 1024**2:.2f} MB\n")

üîπ Model summary:
  ‚Ä¢ Trainable params: 37,704,843
  ‚Ä¢ Model params:     143.83 MB
  ‚Ä¢ Static (model + Adam): 431.50 MB



In [8]:
num_samples = int(SR * DURATION)
F = N_FFT // 2 + 1
T = (num_samples - N_FFT) // HOP_LENGTH + 1
B = 1

spectrogram = torch.randn(B, F, T)
phase = torch.randn(B, F, T)
audio_len = num_samples

print("üîπ Input shape info:")
print(f"  ‚Ä¢ FFT bins (F):     {F}")
print(f"  ‚Ä¢ Time frames (T):  {T}")
print(f"  ‚Ä¢ Input tensor:     {tuple(spectrogram.shape)}\n")

üîπ Input shape info:
  ‚Ä¢ FFT bins (F):     3073
  ‚Ä¢ Time frames (T):  29
  ‚Ä¢ Input tensor:     (1, 3073, 29)



In [16]:
def hook_fn(activation_list, module, input, output):
    def sizeof(obj):
        if isinstance(obj, torch.Tensor):
            return obj.numel() * obj.element_size()
        elif isinstance(obj, (list, tuple, set)):
            return sum(sizeof(o) for o in obj)
        elif isinstance(obj, dict):
            return sum(sizeof(v) for v in obj.values())
        else:
            return 0
    activation_list.append(sizeof(output))

activation_sizes = []
hooks = [m.register_forward_hook(partial(hook_fn, activation_sizes)) for m in model.modules()]

In [17]:
with torch.no_grad():
    _ = model(spectrogram, phase, audio_len)

for h in hooks:
    h.remove()

activations_bytes_fp32 = sum(activation_sizes)
activations_bytes_bf16 = activations_bytes_fp32 / 2

print("üîπ Activations per sample:")
print(f"  ‚Ä¢ FP32: {activations_bytes_fp32 / 1024**2:.2f} MB")
print(f"  ‚Ä¢ BF16: {activations_bytes_bf16 / 1024**2:.2f} MB\n")

üîπ Activations per sample:
  ‚Ä¢ FP32: 1551.78 MB
  ‚Ä¢ BF16: 775.89 MB



In [18]:
GPU_BYTES = GPU_MEMORY_GB * 1024**3

def compute_batch_limits(total_mem, static_mem, per_sample_mem):
    b_max = (total_mem - static_mem) // per_sample_mem
    b_safe = int(b_max * 0.9)  # 10% safety margin
    return int(b_max), b_safe

b_max_fp32, b_safe_fp32 = compute_batch_limits(GPU_BYTES, static_bytes, activations_bytes_fp32)
b_max_bf16, b_safe_bf16 = compute_batch_limits(GPU_BYTES, static_bytes, activations_bytes_bf16)

In [19]:
print("============================================================")
print(f"üìã MEMORY SUMMARY (GPU = {GPU_MEMORY_GB} GB)\n")
print(f"üî∏ Static model + optimizer: {static_bytes / 1024**2:.2f} MB")
print("üî∏ Activations per sample:")
print(f"    ‚Ä¢ FP32: {activations_bytes_fp32 / 1024**2:.2f} MB")
print(f"    ‚Ä¢ BF16: {activations_bytes_bf16 / 1024**2:.2f} MB\n")

print(f"üîπ MAX batch size (theoretical):")
print(f"    ‚Ä¢ FP32: {b_max_fp32}")
print(f"    ‚Ä¢ BF16: {b_max_bf16}\n")

print(f"üîπ Recommended safe batch size (~90% VRAM):")
print(f"    ‚Ä¢ FP32: {b_safe_fp32}")
print(f"    ‚Ä¢ BF16: {b_safe_bf16}")
print("============================================================")

üìã MEMORY SUMMARY (GPU = 24 GB)

üî∏ Static model + optimizer: 431.50 MB
üî∏ Activations per sample:
    ‚Ä¢ FP32: 1551.78 MB
    ‚Ä¢ BF16: 775.89 MB

üîπ MAX batch size (theoretical):
    ‚Ä¢ FP32: 15
    ‚Ä¢ BF16: 31

üîπ Recommended safe batch size (~90% VRAM):
    ‚Ä¢ FP32: 13
    ‚Ä¢ BF16: 27


## Calculate with checkpointing

In [6]:
from src.model import DTTNetModel

model = DTTNetModel( 
    fc_dim=N_FFT // 2 + 1,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_heads=2,
    use_checkpoints=True,
)
model.eval()

DTTNetModel(
  (encoder): Encoder(
    (init_conv): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
    (encoder_layers): ModuleList(
      (0): EncoderBlock(
        (tfc_tdf): TFC_TDF_Block(
          (conv1): Sequential(
            (0): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (1): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
              (2): GELU(approximate='none')
            )
            (2): Sequential(
              (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
              (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
           

In [34]:
from functools import partial
import torch
from src.model.DTTNet.blocks.block_tfc_tdf import TFC_TDF_Block  # –µ—Å–ª–∏ –ø—É—Ç—å –¥—Ä—É–≥–æ–π ‚Äî –ø–æ–ø—Ä–∞–≤—å

activation_sizes = []
hooks = []

def hook_fn(activation_list, module, input, output):
    """–ü–æ–¥—Å—á—ë—Ç –ø–∞–º—è—Ç–∏ –∞–∫—Ç–∏–≤–∞—Ü–∏–π –≤ –±–∞–π—Ç–∞—Ö"""
    def sizeof(obj):
        if isinstance(obj, torch.Tensor):
            return obj.numel() * obj.element_size()
        elif isinstance(obj, (list, tuple, set)):
            return sum(sizeof(o) for o in obj)
        elif isinstance(obj, dict):
            return sum(sizeof(v) for v in obj.values())
        else:
            return 0
    activation_list.append(sizeof(output))

# 1Ô∏è‚É£ –°–æ–±–µ—Ä—ë–º –≤—Å–µ conv1 / conv2 –∏–∑ TFC_TDF_Block (–∏—Ö –Ω—É–∂–Ω–æ –∏—Å–∫–ª—é—á–∏—Ç—å)
excluded_modules = set()
for m in model.modules():
    if isinstance(m, TFC_TDF_Block):
        excluded_modules.add(m.conv1)
        excluded_modules.add(m.conv2)
        # –¥–æ–±–∞–≤–∏–º –∏—Ö –¥–æ—á–µ—Ä–Ω–∏–µ —Å–ª–æ–∏ ‚Äî conv2d, norm –∏ gelu –≤–Ω—É—Ç—Ä–∏
        for sub in m.conv1.modules():
            excluded_modules.add(sub)
        for sub in m.conv2.modules():
            excluded_modules.add(sub)

# 2Ô∏è‚É£ –ù–∞–≤–µ—à–∏–≤–∞–µ–º —Ö—É–∫–∏ –Ω–∞ –≤—Å—ë, –∫—Ä–æ–º–µ –∏—Å–∫–ª—é—á—ë–Ω–Ω—ã—Ö conv-–±–ª–æ–∫–æ–≤
for m in model.modules():
    if m in excluded_modules:
        continue  # –ø—Ä–æ–ø—É—Å–∫–∞–µ–º conv1 –∏ conv2 –≤–Ω—É—Ç—Ä–∏ TFC_TDF_Block
    hooks.append(m.register_forward_hook(partial(hook_fn, activation_sizes)))

# 3Ô∏è‚É£ –ü—Ä–æ–≥–æ–Ω
with torch.no_grad():
    _ = model(spectrogram, phase, audio_len)

# 4Ô∏è‚É£ –£–±–∏—Ä–∞–µ–º —Ö—É–∫–∏
for h in hooks:
    h.remove()

# 5Ô∏è‚É£ –°—á–∏—Ç–∞–µ–º
activations_bytes_fp32_ckpt = sum(activation_sizes)
activations_bytes_bf16_ckpt = activations_bytes_fp32_ckpt / 2

print("üîπ Activations (excluding conv1/conv2 inside TFC_TDF_Block):")
print(f"  ‚Ä¢ FP32: {activations_bytes_fp32_ckpt / 1024**2:.2f} MB")
print(f"  ‚Ä¢ BF16: {activations_bytes_bf16_ckpt / 1024**2:.2f} MB")

üîπ Activations (excluding conv1/conv2 inside TFC_TDF_Block):
  ‚Ä¢ FP32: 498.10 MB
  ‚Ä¢ BF16: 249.05 MB


In [35]:
GPU_BYTES = GPU_MEMORY_GB * 1024**3

def compute_batch_limits(total_mem, static_mem, per_sample_mem):
    b_max = (total_mem - static_mem) // per_sample_mem
    b_safe = int(b_max * 0.9)  # 10% safety margin
    return int(b_max), b_safe

b_max_fp32_ckpt, b_safe_fp32_ckpt = compute_batch_limits(GPU_BYTES, static_bytes, activations_bytes_fp32_ckpt)
b_max_bf16_ckpt, b_safe_bf16_ckpt = compute_batch_limits(GPU_BYTES, static_bytes, activations_bytes_bf16_ckpt)

In [36]:
print("============================================================")
print(f"üìã MEMORY SUMMARY (GPU = {GPU_MEMORY_GB} GB)\n")
print(f"üî∏ Static model + optimizer: {static_bytes / 1024**2:.2f} MB")
print("üî∏ Activations per sample:")
print(f"    ‚Ä¢ FP32: {activations_bytes_fp32_ckpt / 1024**2:.2f} MB")
print(f"    ‚Ä¢ BF16: {activations_bytes_bf16_ckpt / 1024**2:.2f} MB\n")

print(f"üîπ MAX batch size (theoretical):")
print(f"    ‚Ä¢ FP32: {b_max_fp32_ckpt}")
print(f"    ‚Ä¢ BF16: {b_max_bf16_ckpt}\n")

print(f"üîπ Recommended safe batch size (~90% VRAM):")
print(f"    ‚Ä¢ FP32: {b_safe_fp32_ckpt}")
print(f"    ‚Ä¢ BF16: {b_safe_bf16_ckpt}")
print("============================================================")

üìã MEMORY SUMMARY (GPU = 24 GB)

üî∏ Static model + optimizer: 431.50 MB
üî∏ Activations per sample:
    ‚Ä¢ FP32: 498.10 MB
    ‚Ä¢ BF16: 249.05 MB

üîπ MAX batch size (theoretical):
    ‚Ä¢ FP32: 48
    ‚Ä¢ BF16: 96

üîπ Recommended safe batch size (~90% VRAM):
    ‚Ä¢ FP32: 43
    ‚Ä¢ BF16: 86
