In [1]:
import sys
from pathlib import Path

# Make the repository root (two directories above the notebook's working
# directory) importable so the `experiments`, `models` and `utils` imports in the
# next cell resolve. The path is resolved once so the membership check is
# reliable; the guard keeps this cell idempotent when re-run, and inserting at the
# front lets the project's generically named packages take precedence over any
# installed package that happens to share a name (e.g. `utils`).
REPO_ROOT = str(Path("../..").resolve())
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

In [2]:
import torch
import torch.nn as nn

import time
from tqdm.auto import trange

from experiments.min_gru import MinGTCRN, MinMPNet
from models.gtcrn import GTCRN
from models.mpnet import MPNet
from utils import count_parameters, load_config

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Build each baseline together with its minGRU-based counterpart
# (construction order is GTCRN, MinGTCRN, MPNet, MinMPNet).
gtcrn, min_gtcrn = GTCRN(), MinGTCRN()

config = load_config("../../models/mpnet/config.json")
mpnet, min_mpnet = (
    MPNet(config, num_tsblocks=4),
    MinMPNet(config, num_tsblocks=4),
)

# Count parameters once per model via `count_parameters`, then report each pair.
# "Difference" is (min variant - baseline): negative means the min variant is smaller.
n_gtcrn, n_min_gtcrn = count_parameters(gtcrn), count_parameters(min_gtcrn)
n_mpnet, n_min_mpnet = count_parameters(mpnet), count_parameters(min_mpnet)

print(f"GTCRN params:    {n_gtcrn:,}")
print(f"MinGTCRN params: {n_min_gtcrn:,}")
print(f"Difference:      {n_min_gtcrn - n_gtcrn:,}")
print()
print(f"MPNet params:    {n_mpnet:,}")
print(f"MinMPNet params: {n_min_mpnet:,}")
print(f"Difference:      {n_min_mpnet - n_mpnet:,}")

GTCRN params:    23,669
MinGTCRN params: 15,669
Difference:      -8,000

MPNet params:    2,263,372
MinMPNet params: 1,333,580
Difference:      -929,792


In [4]:
# Sanity check: the minGRU-based variants should not contain any `nn.GRU` layers.
# Each model is scanned on its own and every match is collected, so the verdict
# printed for a model reflects that model only (a `break` in an inner loop cannot
# reach an `else` attached to an outer loop, so for/else is not used here).
for model_name, model in [("MinGTCRN", min_gtcrn), ("MinMPNet", min_mpnet)]:
    gru_layers = [
        (name, type(module).__name__)
        for name, module in model.named_modules()
        if isinstance(module, nn.GRU)
    ]
    if gru_layers:
        print(f"nn.GRU modules in {model_name}:")
        for name, class_name in gru_layers:
            print(f"  {name}: {class_name}")
    else:
        print(f"No nn.GRU modules found in {model_name}")

nn.GRU modules in MinGTCRN:
No nn.GRU modules found in MinGTCRN


In [None]:
# Benchmark the mean forward-pass latency of MPNet vs. its minGRU-based counterpart.
n_runs = 500
n_warmup = 5

# Seed so the random inputs are the same on every re-run of this cell.
torch.manual_seed(0)

# (B, F, T)
noisy_amp = torch.randn(1, 201, 100)
noisy_pha = torch.randn(1, 201, 100)

mpnet.eval()
min_mpnet.eval()


def time_forward(model, *inputs, n_runs, n_warmup=5, desc=None):
    """Return the mean wall-clock latency of ``model(*inputs)`` in milliseconds.

    Args:
        model: Module (or any callable) to benchmark; put it in eval mode first.
        *inputs: Positional arguments passed unchanged to every forward call.
        n_runs: Number of timed forward passes to average over (must be > 0).
        n_warmup: Untimed forward passes executed immediately before timing, so
            one-off start-up costs are excluded from the measurement.
        desc: Optional label shown on the progress bar.

    Returns:
        Mean time per timed forward pass, in milliseconds.
    """
    # Gradients are disabled for warmup and timing alike; entering the context
    # once keeps its enter/exit cost out of the timed loop.
    with torch.no_grad():
        for _ in range(n_warmup):
            model(*inputs)
        start = time.perf_counter()
        for _ in trange(n_runs, desc=desc):
            model(*inputs)
        elapsed = time.perf_counter() - start
    return elapsed / n_runs * 1000


mpnet_time = time_forward(
    mpnet, noisy_amp, noisy_pha, n_runs=n_runs, n_warmup=n_warmup, desc="MPNet"
)
min_mpnet_time = time_forward(
    min_mpnet, noisy_amp, noisy_pha, n_runs=n_runs, n_warmup=n_warmup, desc="MinMPNet"
)

print(f"MPNet:    {mpnet_time:.2f} ms/forward")
print(f"MinMPNet: {min_mpnet_time:.2f} ms/forward")
print(f"Speedup:  {mpnet_time / min_mpnet_time:.2f}x")

  0%|          | 0/500 [00:00<?, ?it/s]