In [None]:
import torch
import audiobox_aesthetics.model.wavlm as wavlm_mod   # this is where GradMultiply lives

# JIT-friendly function: forward pass is pure Tensor math
@torch.jit.script
def _grad_multiply_jit(x: torch.Tensor, scale: float):
    return x                      # identity (we only care about inference!)

# Stand-in object exposing the `.apply(x, scale)` entry point of the old API.
class _GradMultiplyShim:
    """Namespace whose only member is a static `apply` forwarding to `_grad_multiply_jit`."""

    @staticmethod
    def apply(x, scale):
        """Forward `(x, scale)` to the scripted helper and hand back its result."""
        passthrough = _grad_multiply_jit(x, scale)
        return passthrough

# 🚨 IMPORTANT: patch *before* you create / trace the model.
# Rebinds the module-level name `GradMultiply` inside `wavlm_mod` to the shim, so
# every later lookup of `wavlm_mod.GradMultiply` resolves to `_GradMultiplyShim`.
wavlm_mod.GradMultiply = _GradMultiplyShim


In [None]:
class AesWrapper(torch.nn.Module):
    """Tensor-in / tensor-out adapter around a dict-based scoring module.

    `core` is called with ``{"wav": <tensor>}`` and must return a mapping that
    contains the keys "CE", "CU", "PC" and "PQ".
    """

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Add a singleton dim at index 1, run `core`, and stack its four outputs on the last dim."""
        per_axis = self.core({"wav": wav.unsqueeze(1)})      # [B,T] → [B,1,T]
        return torch.stack([per_axis[key] for key in ("CE", "CU", "PC", "PQ")], dim=-1)

In [None]:
# `AesMultiOutput` is not imported by any earlier cell; import it here so this cell
# also runs on a fresh kernel (Restart Kernel → Run All).
from audiobox_aesthetics.model.aes import AesMultiOutput

# Load the pretrained checkpoint in eval mode on CPU.
model = AesMultiOutput.from_pretrained("facebook/audiobox-aesthetics").eval().to("cpu")

# `AesWrapper` is the class defined in the previous cell.
example = torch.randn(1, 16_000 * 10)          # random example input, shape [1, 160000]

# Trace the wrapped model with the example input and write the TorchScript archive.
ts = torch.jit.trace(AesWrapper(model), example, strict=False)
ts.save("audiobox_aesthetics_ts.pt")


In [None]:
# Reload the saved TorchScript archive and move the loaded module to the "cuda" device.
ts = torch.jit.load("audiobox_aesthetics_ts.pt").to("cuda")


In [None]:
# Smoke test: one forward call on random input of shape [1, 80000] created on "cuda".
x  = torch.randn(1, 16000 * 5, device="cuda")     # 5-second dummy audio
ts(x)

In [None]:
# Alternative load path: pass the target device as `map_location` to `torch.jit.load`
# instead of calling `.to(...)` afterwards.
device = torch.device("cuda:0")          # or "cuda:1", …
ts = torch.jit.load("audiobox_aesthetics_ts.pt", map_location=device)
ts = ts.eval()                           # just a habit; avoids dropout, etc.


In [None]:
import torch

# Load onto CPU first, then move the whole module to `device` and run one forward call.
device = "cuda:0"                           # pick your GPU

ts = torch.jit.load("audiobox_aesthetics_ts.pt", map_location="cpu")
# NOTE(review): the error this move is said to fix is not shown in this notebook;
# presumably a device mismatch between module tensors and input — confirm.
ts = ts.to(device)                          # <── this line fixes the error
ts = ts.eval()

x = torch.randn(1, 16_000 * 5, device=device)   # 5-s dummy audio
scores = ts(x)                                # module and input both on `device`
print(scores)                                 # tensor([[CE, CU, PC, PQ]])


# Trace for CPU

In [58]:
import torch
from audiobox_aesthetics.model.aes import AesMultiOutput
import audiobox_aesthetics.model.wavlm as wavlm_mod

# --------------------------------------------------------------------------- #
# 1.  Gra​dMultiply shim  (unchanged)
@torch.jit.script
def _grad_multiply_jit(x: torch.Tensor, scale: float):
    return x
class _GradMultiplyShim:
    """Holder for a static `apply(x, scale)` that delegates to `_grad_multiply_jit`."""

    @staticmethod
    def apply(x, scale):
        """Pass both arguments to the scripted helper and return what it returns."""
        passthrough = _grad_multiply_jit(x, scale)
        return passthrough
# Install the shim under the name `GradMultiply` in `wavlm_mod` before the model is loaded.
wavlm_mod.GradMultiply = _GradMultiplyShim
# --------------------------------------------------------------------------- #

# Load the pretrained checkpoint, move it to `device`, and switch it to eval mode.
device = "cpu"
core   = AesMultiOutput.from_pretrained("facebook/audiobox-aesthetics") \
                       .to(device).eval()

# --------------------------------------------------------------------------- #
# 2.  Wrapper with the inverse z-score step (value * std + mean) built in
class AesWrapper(torch.nn.Module):
    """Tensor-in / tensor-out adapter that also applies ``z * std + mean``.

    `core` must be callable with ``{"wav": <tensor>}``, return a mapping with the
    keys "CE", "CU", "PC", "PQ", and expose `target_transform[axis]["mean"]` and
    `target_transform[axis]["std"]` for each of those keys.
    """

    def __init__(self, core):
        super().__init__()
        self.core = core

        # Read the per-axis statistics from `core` once, in the fixed output order,
        # and keep them as float32 buffers named "mean" and "std".
        axes = ["CE", "CU", "PC", "PQ"]
        for stat_name in ("mean", "std"):
            values = [core.target_transform[axis][stat_name] for axis in axes]
            self.register_buffer(stat_name, torch.tensor(values, dtype=torch.float32))

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Run `core` on `wav` (given a singleton dim at index 1) and rescale its outputs."""
        # input wav shape  [B, T]  (16 kHz mono, -1..1)
        outputs = self.core({"wav": wav.unsqueeze(1)})          # → [B,1,T]
        stacked = torch.stack(
            [outputs["CE"], outputs["CU"], outputs["PC"], outputs["PQ"]], -1
        )
        return stacked * self.std + self.mean                   # rescale inside the graph
# --------------------------------------------------------------------------- #

# Wrap the loaded model and keep the wrapper (including its buffers) on `device`.
wrapper = AesWrapper(core).to(device)

# ----- trace or script ------------------------------------------------------
# The trace is recorded from a single example input of shape [1, 160000].
example = torch.randn(1, 16_000 * 10, device=device)   # 10-s dummy
ts = torch.jit.trace(wrapper, example, strict=False)    # or  torch.jit.script(wrapper)
ts.eval()                                              # freeze below is applied to an eval-mode module
ts = torch.jit.freeze(ts)                              # optional but nice

# Persist the frozen module and report success.
ts.save("audiobox_aesthetics_ts_cpu.pt")
print("✅ TorchScript bundle with baked-in means/stds saved.")


✅ TorchScript bundle with baked-in means/stds saved.


# TEST INFERENCE

In [59]:
import torch
# Reload the CPU bundle written by the export cell; this rebinds the notebook-global `ts`.
ts = torch.jit.load("audiobox_aesthetics_ts_cpu.pt")
ts = ts.eval()


In [64]:
# Call the loaded module three times on fresh random input and print each result.
with torch.inference_mode():
    for _ in range(3):           # call repeatedly – no crash
        x  = torch.randn(1, 16000*10,device="cpu")
        print(ts(x))             # tensor([[CE, CU, PC, PQ]])
# Runs one more forward pass (outside the inference_mode block) to display the output shape.
ts(x).shape

tensor([[2.3310, 4.8411, 2.0928, 4.7827]])
tensor([[2.2362, 4.7066, 2.0517, 4.7879]])
tensor([[2.1967, 4.3483, 2.1122, 4.6560]])


torch.Size([1, 4])

# MAKE ONNX

In [95]:
import torch.onnx as onnx

# Example input handed to the exporter; shape [1, 160000].
dummy = torch.randn(1, 16000*10,device="cpu")

# 3. Export
# NOTE(review): the export itself completes, but the onnxruntime cell further down
# records a broadcast failure in a `Where` node ("499 by 768"). Presumably shapes
# from the traced example are fixed in the exported graph — confirm before relying
# on this file.
onnx.export(
    ts,                      # the ScriptModule
    dummy,                   # example input
    "audiobox_aesthetics.onnx",
    export_params=True,      # store learned weights in the file
    opset_version=18,        # or whichever ONNX version you target
    input_names=["audio"],
    output_names=["scores"],
    # Only axis 1 of "audio" is declared dynamic; axis 0 and the output axes are not.
    dynamic_axes={
        "audio":  {1: "n_samples"},   # allow arbitrary length waveforms
    },
)
print("✅  ONNX export complete → audiobox_aesthetics.onnx")

✅  ONNX export complete → audiobox_aesthetics.onnx


In [104]:
import onnxruntime as rt
# Run the exported ONNX file once through onnxruntime and print the first output.
# NOTE(review): `dummy` is freshly sampled here, so comparing against TorchScript
# requires calling `ts(dummy)` on this same tensor.
dummy = torch.randn(1,16000*10,device="cpu")
sess = rt.InferenceSession("audiobox_aesthetics.onnx")
out = sess.run(None, {"audio": dummy.numpy()})[0]    # feed by the input name used at export
print(torch.tensor(out)) # should match TorchScript output

[1;31m2025-06-22 01:07:36.556064347 [E:onnxruntime:, sequential_executor.cc:572 ExecuteKernel] Non-zero status code returned while running Where node. Name:'/Where' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 499 by 768
[m


RuntimeException: [ONNXRuntimeError] : 6 : RUNTIME_EXCEPTION : Non-zero status code returned while running Where node. Name:'/Where' Status Message: /onnxruntime_src/onnxruntime/core/providers/cpu/math/element_wise_ops.h:540 void onnxruntime::BroadcastIterator::Init(ptrdiff_t, ptrdiff_t) axis == 1 || axis == largest was false. Attempting to broadcast an axis by a dimension other than 1. 499 by 768


In [75]:
import torch
# Reload the CPU bundle and print its TorchScript graph for inspection.
# NOTE(review): this rebinds the notebook-global `model` to the loaded TorchScript
# module; other cells read `model.target_transform` — confirm which object they expect.
model = torch.jit.load('audiobox_aesthetics_ts_cpu.pt')
print(model.graph)  # Examine operations

graph(%self : __torch__.___torch_mangle_3677.AesWrapper,
      %wav.1 : Tensor):
  %3713 : Tensor = prim::Constant[value= 5.0686  5.7363  3.1859  6.5750 [ CPUFloatType{4} ]]() # :0:0
  %3711 : Tensor = prim::Constant[value= 1.9303  1.7567  1.8664  1.5147 [ CPUFloatType{4} ]]() # :0:0
  %3700 : Tensor = prim::Constant[value={-0.0261768}]() # :0:0
  %3699 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3691 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3690 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3686 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3685 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3677 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3676 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3672 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3671 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3663 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %3662 : Tensor = prim::Constant[value=<Tensor>]() # :0:0
  %

In [85]:
import torchaudio
device = "cpu"
# 3. Inference helper ---------------------------------------------------------
@torch.inference_mode()
def score(file_or_tensor, as_dict=False):
    """Score audio with the notebook-global TorchScript module `ts`.

    Args:
        file_or_tensor: Either a path string or a tensor. A path is read with
            `torchaudio.load`, averaged over its first dimension when that has
            more than one entry, and resampled to 16 kHz when the file's sample
            rate differs. Anything else is passed through unchanged.
        as_dict: When False (default) the raw output of `ts` is returned, which
            is what this function has always returned. When True the output is
            converted to a ``{"CE", "CU", "PC", "PQ"}`` dict of Python floats —
            the form the previously unreachable second `return` produced.

    Returns:
        The output of `ts(wav.to(device))`, or a dict of four floats when
        `as_dict` is True.

    Raises:
        ValueError: If `as_dict` is True and the output, after dropping a
            leading singleton dim, is not one-dimensional (e.g. batch size > 1).
    """
    if isinstance(file_or_tensor, str):
        wav, sr = torchaudio.load(file_or_tensor)
        print(wav.shape, sr)
        if wav.size(0) > 1:                        # more than one channel → average them
            wav = wav.mean(dim=0, keepdim=True)    # [1, T]
        if sr != 16_000:                           # model expects 16 kHz
            wav = torchaudio.functional.resample(wav, sr, 16_000)
        print(wav.shape, 16_000)
    else:
        wav = file_or_tensor                       # already a tensor

    scores = ts(wav.to(device))                    # runs on the module-level `device`
    if not as_dict:
        return scores

    flat = scores.squeeze(0)
    if flat.dim() != 1:
        raise ValueError(
            f"as_dict=True needs a single-item batch, got output shape {tuple(scores.shape)}"
        )
    return dict(zip(["CE", "CU", "PC", "PQ"], flat.tolist()))

# --------------------------------------------------------------------------- #
# Score one audio file by path; the commented line is the random-tensor alternative.
#x = torch.randn(1, 16000*10, device="cuda:0")
x = "carbomb.mp3"
xscore = score(x)
xscore


torch.Size([2, 13174319]) 44100
torch.Size([1, 4779799]) 16000


KeyboardInterrupt: 

In [50]:
# Rescale each axis with that axis' statistics: value * std + mean.
# NOTE(review): `xscore` is indexed by axis name, so this expects a mapping, and
# `model` must expose `target_transform` — confirm which cells produced both.
for stat in ["CE","CU","PC","PQ"]:
    norm_score = (xscore[stat] * model.target_transform[stat]['std']) +  model.target_transform[stat]['mean']
    print(stat, norm_score)

CE 6.492822108002901
CU 7.496735419051647
PC 6.63395175412178
PQ 7.312182024731636


In [51]:
# Load the file as-is. The preprocessing below (channel average, 16 kHz resample,
# crop to the first 20 s) is disabled, so `wav` / `sr` stay exactly as returned by
# `torchaudio.load`.
wav, sr = torchaudio.load("carbomb.mp3")
# print(wav.shape, sr)
# if wav.size(0) > 1:                        # already mono? skip
#     wav = wav.mean(dim=0, keepdim=True)    # [1, T]
# if sr != 16_000:                 # model expects 16 kHz
#     wav = torchaudio.functional.resample(wav, sr, 16_000)
# sr = 16_000
# wav = wav[:,:16_000*20]
# wav.shape

In [52]:
from audiobox_aesthetics import infer as aes_infer
# Reference run through the package's own predictor, for comparison with the
# TorchScript scores.
aes_predictor = aes_infer.initialize_predictor()
# The waveform tensor is passed under the "path" key together with its sample rate.
# NOTE(review): `transcription` holds the first result entry (a dict of axis scores
# per the recorded output), not a transcript — the name is misleading.
transcription = aes_predictor.forward([{"path": wav, "sample_rate": sr}])[0]
transcription

{'CE': 6.288222312927246,
 'CU': 7.438724517822266,
 'PC': 6.5343337059021,
 'PQ': 7.253620624542236}

In [28]:
# Inspect the transform for a single metric, e.g., 'PQ'
# NOTE(review): requires `model` to be an object exposing `target_transform`
# (a mapping of axis -> {"mean", "std"}) — confirm which cell bound `model`.
pq_transform = model.target_transform['PQ']
pq_mean = pq_transform['mean']
pq_std = pq_transform['std']

print(f"PQ Mean: {pq_mean}")
print(f"PQ Std: {pq_std}")

# You can store these values to use with your TorchScript output
# Per-axis copy of the mean/std pairs, keyed by axis name.
INVERSE_NORM_CONSTANTS = {
    axis: {
        'mean': model.target_transform[axis]['mean'],
        'std': model.target_transform[axis]['std']
    }
    for axis in ["CE", "CU", "PC", "PQ"]
}

PQ Mean: 6.57505
PQ Std: 1.51466


In [31]:

# 3) 10-s hop / window split, like infer.py (optional but recommended)
def split_to_windows(wav, sr=16_000, win=10, hop=10):
    """Cut `wav` along its last dimension into fixed-length windows.

    Args:
        wav: Tensor whose last dimension is the sample axis.
        sr: Samples per second; converts `win` and `hop` to sample counts.
        win: Window length in seconds.
        hop: Distance between consecutive window starts, in seconds.

    Returns:
        Tensor of shape ``(n_windows, *wav.shape[:-1], win * sr)``. A window that
        runs past the end of `wav` is right-padded with zeros. When `wav` has no
        samples the result has ``n_windows == 0`` (previously this case reached
        `torch.stack` with an empty list and raised).

    Raises:
        ValueError: If ``hop * sr`` is not positive.
    """
    win_len = win * sr
    step = hop * sr
    if step <= 0:
        raise ValueError(f"hop * sr must be positive, got {step}")

    total = wav.shape[-1]
    if total == 0:
        # Nothing to window: return an empty batch with the expected trailing shape.
        return wav.new_zeros((0, *wav.shape[:-1], win_len))

    segments = []
    for start in range(0, total, step):
        seg = wav[..., start:start + win_len]
        missing = win_len - seg.shape[-1]
        if missing > 0:                       # right-pad with zeros
            seg = torch.nn.functional.pad(seg, (0, missing))
        segments.append(seg)
    return torch.stack(segments)

#wav = torch.randn(1, 16000*10*3, device="cuda:1")    # dummy 30-s clip
# Split the global `wav` into windows and score all of them in one batched call.
# NOTE(review): the traceback recorded under this cell fails inside this `ts(...)`
# call (a `view` in the traced attention code). Presumably the trace, recorded
# with a batch of 1, does not accept N > 1 — confirm, or score one window per call.
windows = split_to_windows(wav)                      # stacked windows: (N, *wav.shape[:-1], T)
pred_z  = ts(windows.squeeze(1))                     # (N,4)  z-scores

# 4) inverse-normalize and length-average
# NOTE(review): `AXES`, `stats` and `json` are not defined or imported in any cell
# visible in this notebook — confirm where they come from.
pred = {}
for i, axis in enumerate(AXES):
    mean, std = stats[axis]["mean"], stats[axis]["std"]
    raw = pred_z[:, i] * std + mean                  # 1-10 scale
    pred[axis] = raw.mean().item()                   # same weight everywhere
print(json.dumps(pred, indent=2))

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__/___torch_mangle_1402.py", line 101, in forward
    _32 = torch.slice(_31, -1, 1536, 2304)
    _33 = torch.slice(_31, -1, 768, 1536)
    _34 = torch.view(torch.slice(_31, -1, 0, 768), [_8, _19, _30])
          ~~~~~~~~~~ <--- HERE
    q = torch.transpose(_34, 0, 1)
    _35 = torch.view(_33, [torch.size(_33, 0), _19, _30])

Traceback of TorchScript, original code (most recent call last):
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/functional.py(6271): multi_head_attention_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/wavlm.py(567): forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1730): _slow_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1751): _call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1740): _wrapped_call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/wavlm.py(1571): forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1730): _slow_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1751): _call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1740): _wrapped_call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/wavlm.py(1449): extract_features
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/wavlm.py(1415): forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1730): _slow_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1751): _call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1740): _wrapped_call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/wavlm.py(1212): extract_features
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/audiobox_aesthetics/model/aes.py(162): forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1730): _slow_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1751): _call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1740): _wrapped_call_impl
/tmp/ipykernel_466784/474887763.py(26): forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1730): _slow_forward
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1751): _call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/nn/modules/module.py(1740): _wrapped_call_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/jit/_trace.py(1276): trace_module
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/jit/_trace.py(696): _trace_impl
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/torch/jit/_trace.py(1000): trace
/tmp/ipykernel_466784/474887763.py(33): <module>
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3579): run_code
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3519): run_ast_nodes
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3336): run_cell_async
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/async_helpers.py(128): _pseudo_sync_runner
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3132): _run_cell
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3077): run_cell
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/ipkernel.py(449): do_execute
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/kernelbase.py(778): execute_request
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/ipkernel.py(362): execute_request
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/kernelbase.py(437): dispatch_shell
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/kernelbase.py(534): process_one
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/kernelbase.py(545): dispatch_queue
/usr/lib/python3.10/asyncio/events.py(80): _run
/usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
/usr/lib/python3.10/asyncio/base_events.py(603): run_forever
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/tornado/platform/asyncio.py(205): start
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel/kernelapp.py(739): start
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/traitlets/config/application.py(1075): launch_instance
/media/dadatron/squirrel/sa/venvs/sad310/lib/python3.10/site-packages/ipykernel_launcher.py(18): <module>
/usr/lib/python3.10/runpy.py(86): _run_code
/usr/lib/python3.10/runpy.py(196): _run_module_as_main
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.


# torch.jit.script

In [None]:
import torch
import torch.nn.utils as nn_utils
import audiobox_aesthetics.model.wavlm as wavlm_mod
from audiobox_aesthetics.model.aes import AesMultiOutput

# ================================================================= #
# FINAL SET OF PATCHES AND WRAPPERS
# ================================================================= #

# PATCH 1: Fix the unscriptable FloorDiv operator (`//=`)
@torch.jit.script
def _bucket_relative_position_fixed(relative_positions: torch.Tensor, bidirectional: bool = True):
    # ... (full function code as before)
    num_buckets = 320; max_distance = 128; relative_buckets = torch.zeros_like(relative_positions, dtype=torch.long)
    if bidirectional:
        num_buckets = num_buckets // 2; relative_buckets += (relative_positions > 0).to(torch.long) * num_buckets
        relative_positions = torch.abs(relative_positions)
    else:
        relative_positions = -torch.min(relative_positions, torch.zeros_like(relative_positions))
    max_exact = num_buckets // 2; is_small = relative_positions < max_exact
    val_if_large = max_exact + (torch.log(relative_positions.float() / max_exact) / torch.log(torch.tensor(max_distance / max_exact)) * (num_buckets - max_exact)).to(torch.long)
    val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
    relative_buckets += torch.where(is_small, relative_positions, val_if_large)
    return relative_buckets
# Install the scripted replacement under the module-level name `_bucket_relative_position`.
# NOTE(review): nothing visible here confirms that `wavlm_mod` looks this name up at
# call time — verify that the patch is actually picked up.
print("Applying Patch 1: Fixing 'wavlm_mod._bucket_relative_position'...")
wavlm_mod._bucket_relative_position = _bucket_relative_position_fixed

# DEFINITIVE WRAPPER for PATCH 4
# Gives the tracer a module whose forward returns a single value.
class MHA_Tracer_Wrapper(torch.nn.Module):
    """Adapter that returns only element 0 of the wrapped module's result."""

    def __init__(self, mha_module):
        super().__init__()
        self.mha_module = mha_module

    def forward(self, q, k, v):
        """Call the wrapped module with `(q, k, v)` and return index 0 of what it returns."""
        # The wrapped module may return (tensor, None); only the first element is kept.
        return self.mha_module(q, k, v)[0]

# ------------------------------------------------------------------ #
# Load model and create top-level wrapper
# The pretrained checkpoint is loaded and switched to eval mode.
model = AesMultiOutput.from_pretrained("facebook/audiobox-aesthetics").eval()
class AesWrapper(torch.nn.Module):
    """Tensor-in / tensor-out adapter: feeds `core` a ``{"wav": ...}`` dict and stacks its four outputs."""

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Add a singleton dim at index 1, call `core`, stack CE/CU/PC/PQ on the last dim."""
        wav_3d = wav.unsqueeze(1)
        out = self.core({"wav": wav_3d})
        return torch.stack([out[k] for k in ("CE", "CU", "PC", "PQ")], -1)
# Wrap the loaded model; everything below mutates `wrapper` (and thus `model`) in place.
wrapper = AesWrapper(model).cpu().eval()

# PATCH 2: Find and delete ALL 'precision' attributes
print("\nApplying Patch 2: Searching for and deleting all 'precision' attributes...")
for name, module in wrapper.named_modules():
    if hasattr(module, 'precision'):
        print(f"  - Found and deleted 'precision' in submodule: {name}")
        delattr(module, 'precision')

# PATCH 3: Remove all weight_norm hooks
print("\nApplying Patch 3: Removing all 'weight_norm' hooks...")
for name, module in wrapper.named_modules():
    # Tried on every submodule; a ValueError from `remove_weight_norm` is skipped.
    try:
        nn_utils.remove_weight_norm(module)
    except ValueError:
        continue
print("  - Finished removing weight_norm hooks.")

# HYBRID PATCH 4: Pre-trace the problematic MultiheadAttention using the clean wrapper
print("\nApplying Hybrid Patch 4: Pre-tracing 'MultiheadAttention' submodules via a wrapper...")
# `list(...)` snapshots the module tree, because the loop body replaces submodules.
for name, module in list(wrapper.named_modules()):
    if isinstance(module, wavlm_mod.MultiheadAttention):
        print(f"  - Wrapping and tracing submodule: {name}")
        # Fall back to `embed_dim` when the module leaves `kdim` / `vdim` as None.
        embed_dim = module.embed_dim; kdim = module.kdim if module.kdim is not None else embed_dim
        vdim = module.vdim if module.vdim is not None else embed_dim
        # Example inputs for the trace: shape [10, 1, dim] each.
        # NOTE(review): the trace is recorded only for these example shapes and for a
        # (q, k, v)-only call returning one tensor — confirm the parent module calls
        # the attention the same way and accepts a non-tuple result.
        q = torch.randn(10, 1, embed_dim); k = torch.randn(10, 1, kdim); v = torch.randn(10, 1, vdim)
        
        # 1. Create an instance of our clean wrapper
        tracer_wrapper = MHA_Tracer_Wrapper(module.eval())
        
        # 2. Trace the WRAPPER, not the original module
        traced_submodule = torch.jit.trace(tracer_wrapper, (q, k, v))
        
        # 3. Replace the original module with the new, traced, clean-output module
        # Walk the dotted module path down to the parent, then swap the last attribute.
        path = name.split('.'); parent = wrapper
        for p in path[:-1]:
            parent = getattr(parent, p)
        setattr(parent, path[-1], traced_submodule)
        print(f"  - Replaced '{name}' with its traced version.")

# ================================================================= #
# FINAL SCRIPTING ATTEMPT
# ================================================================= #
@torch.jit.script
def _gradmul(x: torch.Tensor, scale: float): return x
class _GradMulShim:
    """Holder for a static `apply(x, scale)` that delegates to `_gradmul`."""

    @staticmethod
    def apply(x, scale):
        """Pass both arguments to `_gradmul` and return its result."""
        passthrough = _gradmul(x, scale)
        return passthrough
# Install the shim under `wavlm_mod.GradMultiply`.
# NOTE(review): in this cell the shim is installed after the model was loaded and
# after the attention submodules were pre-traced — confirm that ordering is intended.
wavlm_mod.GradMultiply = _GradMulShim
# Set `grad_mult` to 0.0 on every submodule that already has such an attribute.
model.apply(lambda m: setattr(m, "grad_mult", 0.0) if hasattr(m, "grad_mult") else None)

print("\nScripting model...")
# Compile the whole wrapper with the TorchScript compiler (no example input involved).
ts_portable = torch.jit.script(wrapper)
ts_portable.eval()

# Freeze the eval-mode module, then write it to disk.
ts_portable = torch.jit.freeze(ts_portable)
ts_portable.save("audiobox_aesthetics_ts_portable.pt")
print("\n\n✅✅✅ VICTORY! ✅✅✅\nPortable TorchScript saved successfully as audiobox_aesthetics_ts_portable.pt")

In [None]:
import torch, audiobox_aesthetics.model.wavlm as wavlm_mod
from audiobox_aesthetics.model.aes import AesMultiOutput

# ------------------------------------------------------------------ #
# 0.  TorchScript-friendly GradMultiply (identity in forward pass)
@torch.jit.script
def _gradmul(x: torch.Tensor, scale: float):   return x
class _GradMulShim:
    """Namespace exposing a static `apply(x, scale)` backed by `_gradmul`."""

    @staticmethod
    def apply(x, scale):
        """Delegate to `_gradmul` and return its result."""
        passthrough = _gradmul(x, scale)
        return passthrough
# Install the shim before the model is loaded.
wavlm_mod.GradMultiply = _GradMulShim
# ------------------------------------------------------------------ #

# 1.  Load the model from the HF hub (**CPU**)
model = AesMultiOutput.from_pretrained("facebook/audiobox-aesthetics").eval()

# 2.  Disable gradient scaling (training-only feature)
# Sets `grad_mult` to 0.0 on every submodule that already has that attribute;
# modules without it are left untouched.
model.apply(lambda m: setattr(m, "grad_mult", 0.0) if hasattr(m, "grad_mult") else None)

# 3.  Thin wrapper → tensor in / tensor out
class AesWrapper(torch.nn.Module):
    """Adapter that hides the dict-based interface of `core` behind plain tensors."""

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Add a singleton dim at index 1, call `core`, stack CE/CU/PC/PQ on the last dim."""
        preds = self.core({"wav": wav.unsqueeze(1)})         # [B,T] → [B,1,T]
        return torch.stack([preds["CE"], preds["CU"], preds["PC"], preds["PQ"]], -1)
# Wrap the model, keep it on CPU, and make sure it is in eval mode before tracing.
wrapper = AesWrapper(model).cpu()
wrapper.eval()                          # ← important

# 4.  Trace on CPU  ➜  freeze  ➜  save
# The trace is recorded from a single example of shape [1, 160000].
example = torch.randn(1, 16_000 * 10)          # 10-s dummy audio
ts = torch.jit.trace(wrapper, example, strict=False)
ts.eval()                               # usually redundant but explicit

ts = torch.jit.freeze(ts)                      # inlines parameters/attributes as constants
ts.save("audiobox_aesthetics_ts.pt")
print("✅  portable TorchScript saved as audiobox_aesthetics_ts.pt")


In [None]:
# Reload the saved archive, move it to "cuda:1", and run one forward call on random input.
ts = torch.jit.load("audiobox_aesthetics_ts.pt").to("cuda:1").eval()
x  = torch.randn(1, 16_000*10, device="cuda:1")      # same length as the trace example
print(ts(x))                                # runs every call
