In [2]:
import numpy as np

SR = 44100
# Seeded uniform noise over the full [-1, 1) range so both halves of the
# quantizer are exercised (np.random.rand only covers [0, 1)) and so the
# invariance checks below are reproducible after a kernel restart.
NOISE = np.random.default_rng(0).uniform(-1.0, 1.0, SR * 10).astype(np.float32)
t = np.arange(SR * 10).astype(np.float32) / SR
SINE = np.sin(2.0 * np.pi * t * 440.0)  # 10 s, 440 Hz tone, float32
EXTREME = np.array([-1.0, 0.0, 1.0], dtype=np.float32)  # range endpoints and midpoint

In [67]:
def quantize_unsigned_pcm(x: np.ndarray, n_bits: int) -> np.ndarray:
    """Quantize float32 audio in [-1, 1] to unsigned (offset-binary) integer codes.

    The code is ``floor((x + 1) / 2 * 2**n_bits)`` clamped to
    ``[0, 2**n_bits - 1]``. Scaling by a power of two and flooring makes the
    codes "MSB invariant": the top ``k`` bits of an ``n``-bit code equal the
    ``k``-bit code of the same input (see ``msb``).

    Args:
        x: float32 array with every value in [-1, 1]. May be empty.
        n_bits: Output bit depth, 1..64.

    Returns:
        uint64 array with the same shape as ``x``.

    Raises:
        ValueError: wrong dtype, NaN values, values outside [-1, 1], or
            ``n_bits`` outside 1..64.
    """
    if x.dtype != np.float32:
        raise ValueError("x must be float32")
    if n_bits < 1 or n_bits > 64:
        raise ValueError("n_bits must be between 1 and 64")
    if x.size == 0:
        # min()/max() below have no identity for empty arrays.
        return np.zeros(x.shape, dtype=np.uint64)
    if np.isnan(x).any():
        # NaN compares False against both bounds, so the range check would miss it.
        raise ValueError("x must not contain NaN")
    if x.min() < -1 or x.max() > 1:
        raise ValueError("x must be in range [-1, 1]")

    # Map from [-1, 1] to [0, 1] (closed: x == 1 maps to exactly 1.0).
    x_normalized = (x + 1) / 2

    # Scale by 2^n_bits (not 2^n_bits - 1) to maintain MSB invariance. A Python
    # float keeps the product in float32, and multiplying by a power of two is exact.
    scale = float(2 ** n_bits)

    # Use floor (not round) to maintain MSB invariance.
    x_floored = np.floor(x_normalized * scale)

    # Only x_normalized == 1.0 lands outside [0, 2^n_bits - 1] (at exactly
    # 2^n_bits), so clamp it in the integer domain. For n_bits > 24 the value
    # 2^n_bits - 1 is not representable in float32, and a float clip would round
    # it back up to 2^n_bits (which for n_bits == 64 also overflows uint64).
    max_val = np.uint64((1 << n_bits) - 1)
    below_top = x_normalized < 1.0
    codes = np.where(below_top, x_floored, 0).astype(np.uint64)
    return np.where(below_top, codes, max_val)


def msb(x: np.ndarray, orig_n_bits: int, n_bits: int) -> np.ndarray:
    """Return the ``n_bits`` most significant bits of ``orig_n_bits``-bit codes.

    Args:
        x: uint64 array of codes, each expected to fit in ``orig_n_bits`` bits.
        orig_n_bits: Bit width of the codes in ``x`` (at most 64).
        n_bits: Number of leading bits to keep (0..orig_n_bits).

    Returns:
        uint64 array holding the top ``n_bits`` bits, right-aligned.

    Raises:
        ValueError: if ``x`` is not uint64 or the bit widths are inconsistent
            (a negative shift would otherwise fail with an obscure error).
    """
    if x.dtype != np.uint64:
        raise ValueError("x must be uint64")
    if not 0 <= n_bits <= orig_n_bits <= 64:
        raise ValueError("bit widths must satisfy 0 <= n_bits <= orig_n_bits <= 64")
    # uint64-typed operands keep the result uint64 under any numpy promotion rules;
    # the mask drops any stray bits above orig_n_bits.
    shift = np.uint64(orig_n_bits - n_bits)
    mask = np.uint64((1 << n_bits) - 1)
    return (x >> shift) & mask

def lsb(x: np.ndarray, n_bits: int) -> np.ndarray:
    """Return the ``n_bits`` least significant bits of each uint64 code.

    Args:
        x: uint64 array of codes.
        n_bits: Number of low bits to keep (0..64).

    Returns:
        uint64 array of the masked low bits.

    Raises:
        ValueError: if ``x`` is not uint64 or ``n_bits`` is outside 0..64
            (a wider mask would not fit in uint64).
    """
    if x.dtype != np.uint64:
        raise ValueError("x must be uint64")
    if not 0 <= n_bits <= 64:
        raise ValueError("n_bits must be between 0 and 64")
    return x & np.uint64((1 << n_bits) - 1)

# MSB-invariance check: for every test signal, the top k bits of an n-bit
# quantization must equal the k-bit quantization of the same signal.
for x in [NOISE, SINE, EXTREME]:
    x_24b = quantize_unsigned_pcm(x, 24)
    x_16b = quantize_unsigned_pcm(x, 16)
    x_13b = quantize_unsigned_pcm(x, 13)
    x_8b = quantize_unsigned_pcm(x, 8)

    # top 16 of 24 bits, i.e. (x_24b >> (24 - 16)) & 0xFFFF
    x_24b_16msb = msb(x_24b, 24, 16)
    assert np.array_equal(x_24b_16msb, x_16b)

    # top 13 of 24 bits, i.e. (x_24b >> (24 - 13)) & 0x1FFF
    x_24b_13msb = msb(x_24b, 24, 13)
    assert np.array_equal(x_24b_13msb, x_13b)

    # top 8 of 16 bits, i.e. (x_16b >> (16 - 8)) & 0xFF
    x_16b_8msb = msb(x_16b, 16, 8)
    assert np.array_equal(x_16b_8msb, x_8b)

In [12]:

import torch
def quantize_unsigned_pcm_torch(x: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Torch port of ``quantize_unsigned_pcm``: float32 [-1, 1] -> unsigned codes.

    The code is ``floor((x + 1) / 2 * 2**n_bits)`` clamped to
    ``[0, 2**n_bits - 1]``, which keeps codes MSB invariant across bit depths.

    Args:
        x: float32 tensor with every value in [-1, 1]. May be empty.
        n_bits: Output bit depth, 1..63. The result is int64, so 64-bit codes
            (up to 2**64 - 1) cannot be represented.

    Returns:
        int64 tensor with the same shape and device as ``x``.

    Raises:
        ValueError: wrong dtype, NaN values, values outside [-1, 1], or
            ``n_bits`` outside 1..63.
    """
    if x.dtype != torch.float32:
        raise ValueError("x must be float32")
    if n_bits < 1 or n_bits > 63:
        raise ValueError("n_bits must be between 1 and 63")
    if x.numel() == 0:
        # min()/max() below raise on empty tensors.
        return torch.zeros(x.shape, dtype=torch.int64, device=x.device)
    if torch.isnan(x).any():
        # NaN compares False against both bounds, so the range check would miss it.
        raise ValueError("x must not contain NaN")
    if x.min() < -1 or x.max() > 1:
        raise ValueError("x must be in range [-1, 1]")

    # Map from [-1, 1] to [0, 1] (closed: x == 1 maps to exactly 1.0).
    x_normalized = (x + 1) / 2

    # Scale by 2^n_bits (not 2^n_bits - 1) to maintain MSB invariance; a Python
    # float keeps the product in float32, and a power-of-two scale is exact.
    scale = float(2 ** n_bits)

    # Use floor (not round) to maintain MSB invariance.
    x_floored = torch.floor(x_normalized * scale)

    # Only x_normalized == 1.0 lands at 2^n_bits, one past the top code. Clamp it
    # in the integer domain: for n_bits > 24, 2^n_bits - 1 is not representable
    # in float32, so a float clip would round it back up to 2^n_bits.
    below_top = x_normalized < 1.0
    codes = torch.where(below_top, x_floored, torch.zeros_like(x_floored)).to(torch.int64)
    max_val = (1 << n_bits) - 1
    return torch.where(below_top, codes, torch.full_like(codes, max_val))



def msb_torch(x: torch.Tensor, orig_n_bits: int, n_bits: int) -> torch.Tensor:
    """Return the ``n_bits`` most significant bits of ``orig_n_bits``-bit codes.

    Torch port of ``msb``. Torch has no usable uint64 for these ops, so codes are
    expected as a non-negative integer tensor (e.g. the int64 output of
    ``quantize_unsigned_pcm_torch``) and bit widths are capped at 63.

    Raises:
        ValueError: if the bit widths are inconsistent. A negative shift would
            otherwise silently return garbage.
    """
    if not 0 <= n_bits <= orig_n_bits <= 63:
        raise ValueError("bit widths must satisfy 0 <= n_bits <= orig_n_bits <= 63")
    # The mask drops any stray bits above orig_n_bits.
    return (x >> (orig_n_bits - n_bits)) & ((1 << n_bits) - 1)

def lsb_torch(x: torch.Tensor, n_bits: int) -> torch.Tensor:
    """Return the ``n_bits`` least significant bits of each integer code.

    Torch port of ``lsb`` for non-negative integer tensors (e.g. int64 codes
    from ``quantize_unsigned_pcm_torch``).

    Raises:
        ValueError: if ``n_bits`` is outside 0..63. A wider mask does not fit
            in int64.
    """
    if not 0 <= n_bits <= 63:
        raise ValueError("n_bits must be between 0 and 63")
    return x & ((1 << n_bits) - 1)

# Same MSB-invariance checks as the numpy cell, run through the torch helpers,
# plus a check that the middle byte of a 24-bit code is the low byte of the 16-bit code.
for x in [torch.from_numpy(NOISE), torch.from_numpy(SINE), torch.from_numpy(EXTREME)]:
    x_24b = quantize_unsigned_pcm_torch(x, 24)
    x_16b = quantize_unsigned_pcm_torch(x, 16)
    x_13b = quantize_unsigned_pcm_torch(x, 13)
    x_8b = quantize_unsigned_pcm_torch(x, 8)

    # top 16 of 24 bits, i.e. (x_24b >> (24 - 16)) & 0xFFFF
    x_24b_16msb = msb_torch(x_24b, 24, 16)
    assert torch.equal(x_24b_16msb, x_16b)

    # top 13 of 24 bits, i.e. (x_24b >> (24 - 13)) & 0x1FFF
    x_24b_13msb = msb_torch(x_24b, 24, 13)
    assert torch.equal(x_24b_13msb, x_13b)

    # top 8 of 16 bits, i.e. (x_16b >> (16 - 8)) & 0xFF
    x_16b_8msb = msb_torch(x_16b, 16, 8)
    assert torch.equal(x_16b_8msb, x_8b)

    # The middle 8 bits of the 24-bit code (bits 8-15) should equal the 8 LSBs of the 16-bit code.
    x_24b_8mid = lsb_torch(msb_torch(x_24b, 24, 16), 8)
    assert torch.equal(x_24b_8mid, lsb_torch(x_16b, 8))

In [10]:
# Sanity check: the mask expression used by msb/lsb for n_bits=8 is 0xFF.
((1 << 8) - 1) == 0xFF

True

In [None]:
from torch.utils.data import Dataset
import torch
import torchaudio
import os
import random
import json
from tqdm import tqdm
class MonoWavChunkDataset(Dataset):
    """Fixed-length token chunks from a directory of stereo WAV files.

    Each item loads ``chunk_size + 1`` frames from one file and converts them to
    unsigned 16-bit codes. It then either picks one channel at random or, with
    ``stereo_interleave``, concatenates both channels in random order. With
    ``bit_split``, each 16-bit code is further split into sub-word tokens. The
    item is a next-token ``(input_tokens, target_tokens)`` pair.

    Args:
        data_dir: Directory scanned (non-recursively) for ``*.wav`` files.
        chunk_size: Audio frames per example, before channel concatenation or
            bit splitting multiply the token count.
        sample_rate: Target sample rate; items are resampled when a file's
            rate differs.
        bit_split: False for whole 16-bit tokens; True or 2 for high/low bytes;
            3 for 8+4+4 bits; 4 for four 4-bit nibbles.
        epoch_expansion_factor: Each file is listed this many times per epoch.
        only_lower_bits: Keep only the low byte of each code. Only applies
            when ``bit_split`` is falsy.
        stereo_interleave: Concatenate both channels (left+right or
            right+left) instead of picking one. Despite the name, the samples
            are not interleaved one by one.
    """

    def __init__(self, data_dir, chunk_size=4096, sample_rate=44100, bit_split=False, epoch_expansion_factor=10, only_lower_bits=False, stereo_interleave=False):
        self.files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.wav')]
        self.chunk_size = chunk_size
        self.sample_rate = sample_rate
        self.bit_split = bit_split
        self.epoch_expansion_factor = epoch_expansion_factor
        self.only_lower_bits = only_lower_bits
        self.stereo_interleave = stereo_interleave

        if len(self.files) == 0:
            raise ValueError("files is empty")
        print(f"MonoWavChunkDataset: {len(self.files)} files, chunk_size={self.chunk_size}")
        # Precomputed per-file lengths, chosen by the split name found in data_dir.
        # NOTE(review): these JSON paths are resolved relative to the current working directory.
        pth = 'musdbstereo_lengths_train.json' if 'train' in data_dir else 'musdbstereo_lengths_valid.json' if 'valid' in data_dir else 'musdbstereo_lengths.json'
        # Context manager so the index file handle is closed promptly.
        with open(pth, 'r') as lengths_file:
            lengths = json.load(lengths_file)
        for ix, f in enumerate(tqdm(self.files)):
            self.files[ix] = (f, lengths[os.path.basename(f)])  # (path, length from the JSON index)
        # Repeat the file list so one pass over the dataset visits each file
        # epoch_expansion_factor times, in shuffled order.
        self.files = self.files * self.epoch_expansion_factor
        random.shuffle(self.files)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path, file_length = self.files[idx]
        # NOTE(review): debug output; remove or gate it before training with DataLoader workers.
        print(path)
        # Load one extra frame so inputs and targets can be offset by one token.
        chunk_size = self.chunk_size + 1
        # NOTE(review): random chunk sampling is disabled and every item starts at
        # frame 0 (file_length is unused). Restore the line below after debugging.
        # offset = torch.randint(0, max(1, file_length - chunk_size), (1,)).item()
        offset = 0
        wav, sr = torchaudio.load(path, normalize=True, frame_offset=offset, num_frames=chunk_size, backend="soundfile")
        print(wav.shape, sr, wav.dtype, wav.min(), wav.max())
        # wav, sr = torchaudio.load(path, normalize=False)
        if wav.dtype != torch.int16:
            # NOTE(review): linear_encode min-max scales each chunk (via minmax_scale), so
            # codes are not a fixed function of sample value on this path. Confirm intended.
            wav = linear_encode(wav, bits=16)
        else:
            # Signed int16 -> unsigned offset-binary codes in [0, 65535].
            wav = wav.long() + 32768
        print(wav.dtype, wav.min(), wav.max())
        if self.stereo_interleave:
            # Concatenate the two channels (not sample-interleaved), in random order.
            interleaved = torch.zeros(wav.shape[1] * 2, dtype=wav.dtype)
            if torch.rand(1).item() < 0.5:
                interleaved[:wav.shape[1]] = wav[0]
                interleaved[wav.shape[1]:] = wav[1]
            else:
                interleaved[:wav.shape[1]] = wav[1]
                interleaved[wav.shape[1]:] = wav[0]
            wav = interleaved
        else:
            # Randomly keep the left or right channel.
            if torch.rand(1).item() < 0.5:
                wav = wav[1]  # take right channel only
            else:
                wav = wav[0]  # take left channel only
        # If bit_split is set, split each 16-bit code into several tokens per sample.
        if self.bit_split:
            splits = self.bit_split if type(self.bit_split) is int else 2
            if splits == 2:
                high_bits = (wav >> 8) & 0xFF
                low_bits = wav & 0xFF
                # Offset low bytes by 256 so high (0-255) and low (256-511) tokens are distinct.
                low_bits += 256
                # Interleave high and low tokens: [h0, l0, h1, l1, ...].
                wav = torch.stack([high_bits, low_bits], dim=1).view(-1)
                assert torch.all(wav[0] == high_bits[0])
                assert torch.all(wav[1] == low_bits[0])

                # NOTE(review): debug-only early return. It yields a 4-tuple of per-channel
                # high/low bytes instead of (input_tokens, target_tokens), and the halving
                # assumes stereo_interleave concatenated two channels. Later notebook cells
                # index this tuple (item[3]); restore the normal return when done.
                high_bits_recon = wav[::2]
                low_bits_recon = wav[1::2] - 256
                channel1_hb = high_bits_recon[:high_bits_recon.shape[0]//2]
                channel2_hb = high_bits_recon[high_bits_recon.shape[0]//2:]
                channel1_lb = low_bits_recon[:low_bits_recon.shape[0]//2]
                channel2_lb = low_bits_recon[low_bits_recon.shape[0]//2:]
                return channel1_hb, channel1_lb, channel2_hb, channel2_lb 

            elif splits == 4:
                byte3 = (wav >> 12) & 0x0F
                byte2 = (wav >> 8) & 0x0F
                byte1 = (wav >> 4) & 0x0F
                byte0 = wav & 0x0F
                # Offset each nibble position by a multiple of 16 (0, 16, 32, 48) so
                # positions get disjoint token ranges 0-15, 16-31, 32-47, 48-63.
                byte2 += 16
                byte1 += 32
                byte0 += 48
                wav = torch.stack([byte3, byte2, byte1, byte0], dim=1).view(-1)
                assert torch.all(wav[0] == byte3[0])
                assert torch.all(wav[1] == byte2[0])
                assert torch.all(wav[2] == byte1[0])
                assert torch.all(wav[3] == byte0[0])
            elif splits == 3:
                # Highest 8 bits, then next 4 bits, then lowest 4 bits; offsets give
                # disjoint token ranges 0-255, 256-271, 272-287.
                byte2 = (wav >> 8) & 0xFF
                byte1 = (wav >> 4) & 0x0F
                byte0 = wav & 0x0F
                byte1 += 256
                byte0 += 272
                wav = torch.stack([byte2, byte1, byte0], dim=1).view(-1)
                assert torch.all(wav[0] == byte2[0])
                assert torch.all(wav[1] == byte1[0])
                assert torch.all(wav[2] == byte0[0])
        elif self.only_lower_bits:
            wav = wav & 0xFF  # keep only the lower 8 bits

        # NOTE(review): this resamples integer token sequences (possibly bit-split);
        # resampling presumably belongs before quantization, on float audio. TODO confirm.
        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        # NOTE(review): pads to chunk_size + 1 tokens with q_zero(16) = 32768 regardless of
        # bit_split / stereo_interleave. For split streams the target length is longer and
        # 32768 is outside the token range, so short files yield shorter or odd sequences.
        if len(wav) < self.chunk_size+1:
            wav = torch.nn.functional.pad(wav, (0, self.chunk_size+1 - len(wav)), mode='constant', value=q_zero(bits=16))
        chunk = wav
        tokens = chunk.long()
        # Token count grows with the number of tokens per sample and with channel concatenation.
        seq_len = self.chunk_size
        if self.bit_split:
            seq_len *= self.bit_split if type(self.bit_split) is int else 2
        if self.stereo_interleave:
            seq_len *= 2
        seq_len = min(seq_len + 1, len(tokens))
        tokens = tokens[:seq_len]
        # Next-token prediction pair: targets are inputs shifted left by one.
        input_tokens = tokens[:-1]
        target_tokens = tokens[1:]
        return input_tokens, target_tokens
    

def minmax_scale(tensor, range_min=0, range_max=1):
    """
    Min-max scale ``tensor`` into ``[range_min, range_max]``.

    Inputs with 3+ dims are scaled per slice, reducing over dims 1 and 2 only
    (e.g. per item of a ``(batch, channels, time)`` batch). 2-D inputs such as
    ``(channels, time)`` are scaled per row, and 0-D/1-D inputs as a whole.
    A ``1e-6`` term in the denominator avoids division by zero, so constant
    inputs map to ``range_min``.
    """
    if tensor.dim() <= 1:
        # Whole-tensor extrema; the 0-d results broadcast against the input.
        min_val = tensor.min()
        max_val = tensor.max()
    else:
        # 3+ dims keep the original (1, 2) reduction. 2-D inputs previously
        # raised IndexError because dim 2 does not exist.
        reduce_dims = (1, 2) if tensor.dim() >= 3 else (1,)
        min_val = torch.amin(tensor, dim=reduce_dims, keepdim=True)
        max_val = torch.amax(tensor, dim=reduce_dims, keepdim=True)
    return range_min + (range_max - range_min) * (tensor - min_val) / (max_val - min_val + 1e-6)

def quantize(samples, bits=8, epsilon=0.01):
    """
    Linearly quantize a signal in [0, 1] to integer levels in [0, 2**bits - 1].

    Each value is mapped to ``x * (q_levels - epsilon) + epsilon / 2`` and
    truncated with ``.long()``. The ``epsilon`` margin keeps ``x == 1`` at
    ``q_levels - 1`` instead of spilling into ``q_levels``.

    The input tensor is not modified; a new int64 tensor is returned.
    """
    q_levels = 1 << bits
    # Out-of-place arithmetic: the previous `*=` / `+=` overwrote the caller's tensor.
    scaled = samples * (q_levels - epsilon) + epsilon / 2
    return scaled.long()

def dequantize(samples, bits=8):
    """
    Map integer levels in [0, 2**bits - 1] back to floats in [-1, 1).

    Level ``k`` becomes ``k / 2**(bits - 1) - 1``, so level 0 -> -1.0 and the
    midpoint level ``2**(bits - 1)`` -> 0.0.
    """
    half_levels = (1 << bits) / 2
    as_float = samples.float()
    return as_float / half_levels - 1

def mu_law_encode(audio, bits=8):
    """
    Mu-law compand ``audio`` and quantize it to ``2**bits`` integer levels.

    Steps: min-max scale the input to [-1, 1], apply
    ``sign(x) * log1p(mu * |x|) / log1p(mu)`` with ``mu = 2**bits - 1``,
    shift the result to [0, 1], then linearly quantize with ``quantize``.
    """
    mu = torch.tensor((1 << bits) - 1)
    scaled = minmax_scale(audio, range_min=-1, range_max=1)

    # Companding curve; the 1e-8 nudge inside abs() is kept from the original formulation.
    magnitude = torch.log1p(mu * torch.abs(scaled + 1e-8)) / torch.log1p(mu)
    companded = torch.sign(scaled) * magnitude

    unit_interval = (companded + 1) / 2
    return quantize(unit_interval, bits=bits)

def mu_law_decode(encoded, bits=8):
    """
    Invert ``mu_law_encode``: dequantize, then expand the mu-law curve.

    Uses ``sign(y) * ((1 + mu)**|y| - 1) / mu`` with ``mu = 2**bits - 1``.
    Returned values lie in [-1, 1].
    """
    mu = (1 << bits) - 1
    companded = dequantize(encoded, bits=bits)
    expanded = ((1 + mu) ** torch.abs(companded) - 1) / mu
    return torch.sign(companded) * expanded

def linear_encode(samples, bits=8):
    """
    Min-max scale ``samples`` to [0, 1], then linearly quantize to
    ``2**bits`` integer levels. Works on a clone, so the caller's tensor
    is left untouched.
    """
    working_copy = samples.clone()
    return quantize(minmax_scale(working_copy), bits=bits)

def linear_decode(samples, bits=8):
    """
    Undo ``linear_encode``'s quantization, returning floats in [-1, 1).

    The min-max scaling applied during encoding is not inverted.
    """
    decoded = dequantize(samples, bits=bits)
    return decoded

def q_zero(bits=8):
    """
    Quantized level that represents 0.0: the midpoint ``2**(bits - 1)`` of an
    unsigned ``bits``-bit code range (128 for 8 bits, 32768 for 16 bits).
    """
    midpoint_level = 1 << (bits - 1)
    return midpoint_level


In [58]:
# Set seeds for torch and Python's `random`: the dataset shuffles its file list
# with random.shuffle and picks channels / channel order with torch.rand.
seed = 42
torch.manual_seed(seed)
random.seed(seed)


# NOTE(review): absolute, machine-specific data path; consider a configurable DATA_DIR constant.
dataset = MonoWavChunkDataset(data_dir='/graft3/datasets/pnlong/lnac/sashimi/data/musdb18stereo/train', chunk_size=512, bit_split=2, stereo_interleave=True)

MonoWavChunkDataset: 600 files, chunk_size=512


100%|██████████| 600/600 [00:00<00:00, 611860.54it/s]


In [59]:
# Re-seed so the torch.rand channel-order choice inside __getitem__ is reproducible.
seed = 42
torch.manual_seed(seed)
random.seed(seed)

# With bit_split=2 the dataset currently returns its debug 4-tuple
# (channel1_hb, channel1_lb, channel2_hb, channel2_lb).
item = dataset[0]

/graft3/datasets/pnlong/lnac/sashimi/data/musdb18stereo/train/Secret Mountains - High Horse.0.wav
torch.Size([2, 513]) 44100 torch.int16 tensor(-9225, dtype=torch.int16) tensor(7436, dtype=torch.int16)
torch.int64 tensor(23543) tensor(40204)




In [60]:
# Inspect the debug tuple returned by dataset[0].
item

(tensor([134, 135, 136, 138, 139, 140, 142, 143, 144, 143, 144, 145, 145, 147,
         148, 148, 148, 149, 149, 150, 150, 151, 152, 153, 154, 154, 154, 155,
         155, 155, 156, 155, 155, 155, 154, 152, 152, 153, 151, 150, 150, 149,
         149, 149, 148, 147, 146, 146, 145, 143, 142, 142, 141, 139, 137, 137,
         137, 136, 134, 133, 132, 131, 132, 131, 130, 130, 131, 131, 131, 131,
         133, 134, 135, 137, 139, 141, 143, 145, 146, 146, 147, 147, 146, 146,
         147, 146, 146, 148, 148, 147, 148, 149, 150, 149, 149, 150, 150, 151,
         151, 150, 149, 149, 149, 148, 147, 146, 145, 144, 143, 140, 139, 139,
         138, 136, 136, 136, 136, 136, 136, 136, 136, 136, 136, 137, 137, 138,
         138, 139, 138, 138, 137, 136, 136, 135, 134, 134, 134, 134, 133, 132,
         132, 132, 132, 132, 133, 134, 135, 134, 134, 134, 134, 133, 132, 131,
         130, 130, 129, 129, 129, 128, 127, 127, 127, 126, 124, 124, 125, 124,
         123, 123, 123, 124, 123, 123, 124, 123, 123

In [44]:
# Load the same file directly (default normalize; presumably float32, which
# quantize_unsigned_pcm below requires — TODO confirm) and see if we get the same thing.
# frame_offset=0 / num_frames=513 mirror the dataset's hardcoded offset and chunk_size + 1.
seed = 42
torch.manual_seed(seed)
random.seed(seed)
pth = "/graft3/datasets/pnlong/lnac/sashimi/data/musdb18stereo/train/Secret Mountains - High Horse.0.wav"
ref_wav, sr = torchaudio.load(pth, frame_offset=0, num_frames=513, backend="soundfile")



In [46]:
# Quantize the reference audio with the numpy quantizer at several bit depths.
ref_wav_24b = quantize_unsigned_pcm(ref_wav.numpy(), 24)
ref_wav_16b = quantize_unsigned_pcm(ref_wav.numpy(), 16)
ref_wav_8b = quantize_unsigned_pcm(ref_wav.numpy(), 8)

In [71]:
# Compare the dataset's low bytes for its second concatenated channel (item[3])
# with the reference's 8 LSBs of channel 0; all zeros means they agree.
# NOTE(review): ref_wav_16b_8lsb is defined in a later cell (In [68]), so this cell
# fails on Restart & Run All. Move it below that cell.
item[3] - ref_wav_16b_8lsb[0]

  item[3] - ref_wav_16b_8lsb[0]


tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [68]:
# Bit slices of the reference codes, for comparison with the dataset's tokens.
ref_wav_24b_16msb = msb(ref_wav_24b, 24, 16)
ref_wav_24b_8msb = msb(ref_wav_24b, 24, 8)
ref_wav_16b_8msb = msb(ref_wav_16b, 16, 8)
ref_wav_16b_8lsb = lsb(ref_wav_16b, 8)

In [69]:
# Low byte of each 16-bit reference code, per channel.
ref_wav_16b_8lsb

array([[173,  55, 145, ..., 148, 133, 172],
       [ 53, 202,  16, ..., 226, 251,  29]], shape=(2, 513), dtype=uint64)

In [1]:
import torch

In [2]:
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files
# (or pass weights_only=True where the installed torch supports it).
x = torch.load("/graft3/datasets/haven/sketch2music/0000000010000.pt")

In [None]:
import stable_a

torch.Size([64, 1937])