In [1]:
import random
import jax
import jax.numpy as np
import librosa
from pathlib import Path

class Dataloader:
    def __init__(self, dataset_path, sample_rate):
        self.sample_rate = sample_rate
        self.list_of_files = list(Path(dataset_path).rglob("*.wav"))
    
    def load_and_process(self, file_path):
        # Load file and check sample rate
        y, sr = librosa.load(file_path, sr=None)
        if sr != self.sample_rate:
            y = librosa.resample(y, orig_sr=sr, target_sr=self.sample_rate)
        
        # Generate a random start position for segment extraction
        random_start = random.randint(0, max(len(y) - self.sample_rate, 0))
        y_segment = y[random_start : random_start + self.sample_rate]
        
        return np.array(y_segment)
    
    def get_batch(self, ids):
        # Load each file individually, outside of JAX operations
        batch = [self.load_and_process(self.list_of_files[id]) for id in ids]
        
        # Convert list of arrays to a JAX array with batch dimension
        batch = jax.numpy.stack(batch)
        
        # Add extra dimension for compatibility with batch processing
        batch = jax.numpy.expand_dims(batch, -2)
        
        return batch

    def __len__(self):
        return len(self.list_of_files)

In [None]:

%load_ext tensorboard
import tensorflow as tf
import datetime
%matplotlib inline
from IPython.display import clear_output
import equinox as eqx
import optax
import jax
import matplotlib.pyplot as plt
from encodec import EncodecModel

@eqx.filter_jit
@eqx.filter_value_and_grad(has_aux=True)
def calculate_loss(model, x):
    y = jax.vmap(model)(x)

    # MSE
    loss = jax.numpy.linalg.norm((x - y), axis=-1)
    # loss = jax.numpy.mean((x - y) ** 2, axis=-1)

    loss = jax.numpy.mean(loss)
    # commit_loss = jax.numpy.mean((z_e - jax.lax.stop_gradient(z_q)) ** 2)
    # loss += commit_loss
    return loss, y

@eqx.filter_jit
def make_step(model, optimizer, opt_state, x):
    global i
    (losses, y), grads = calculate_loss(model, x)
    # print_per_layer(model)
    updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_array))
    model = eqx.apply_updates(model, updates)
    # print_per_layer(model, "Model", i)
    # print_per_layer(updates, "Updates", i)
    # print_per_layer(grads, "Grads", i)
    # i+=1
    return losses, y, model, opt_state

# Return truncated weight array
def trunc_init(weight: jax.Array, key: jax.random.PRNGKey) -> jax.Array:
    print(weight.shape)
    out, in_, three = weight.shape
    
    # Calculate fan_in and fan_out
    fan_in = in_
    fan_out = out
    stddev = jax.numpy.sqrt(2.0 / (fan_in)) * 0.001
    
    return stddev * jax.random.normal(key, shape=(out, in_, three))

def init_conv_weight(model, init_fn, key):
    is_conv = lambda x: isinstance(x, (eqx.nn.Conv1d, eqx.nn.ConvTranspose1d))
    get_weights = lambda m: [x.weight
                            for x in jax.tree_util.tree_leaves(m, is_leaf=is_conv)
                            if is_conv(x)]
    weights = get_weights(model)
    new_weights = [init_fn(weight, subkey)
                    for weight, subkey in zip(weights, jax.random.split(key, len(weights)))]
    new_model = eqx.tree_at(get_weights, model, new_weights)
    return new_model

def print_per_layer(model, tree_name, batch_idx):
    # Function to identify Conv1d and ConvTranspose1d layers
    is_conv = lambda x: isinstance(x, (eqx.nn.Conv1d, eqx.nn.ConvTranspose1d))
    
    # Extract weights from layers identified as Conv1d or ConvTranspose1d
    get_weights = lambda m: [x.weight
                             for x in jax.tree_util.tree_leaves(m, is_leaf=is_conv)
                             if is_conv(x)]
    
    weights = get_weights(model)

    # Log weights for each layer separately in TensorBoard
    with train_summary_writer.as_default():
        for i, weight in enumerate(weights):
            flat_weight = jax.numpy.reshape(weight, -1)
            tf.summary.histogram(f'{tree_name}/layer_{i}_weights', flat_weight, step=batch_idx)
    
    # print("Logged histograms for each layer’s weights to TensorBoard.")


log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
train_summary_writer = tf.summary.create_file_writer(log_dir)

batch_size = 64
epochs = 100
learning_rate = 3e-2

key = jax.random.PRNGKey(2)

grab, key = jax.random.split(key, 2)

model = EncodecModel(activation=jax.nn.silu, n_res_layers=0, key=grab)

# sprint_per_layer(model)
model = init_conv_weight(model, trunc_init, grab)
# print_per_layer(model)
optimizer = optax.chain(
    optax.adam(learning_rate, b1=0.5, b2=0.9),
    # optax.add_decayed_weights(0.01)
)
opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

dataset = Dataloader("/home/tugdual/JAXTTS/ResNet/wav", sample_rate=24000)


plt.ion()  # Turn on interactive mode
fig, ax = plt.subplots(figsize=(10, 4))
line_input, = ax.step([], [], label='Input', where='mid')
line_output, = ax.step([], [], label='Output', where='mid')
ax.set_title('Model Input and Output')
ax.set_xlabel('Sample Index')
ax.set_ylabel('Amplitude')
ax.legend()
ax.grid()

for epoch in range(epochs):
    grab, key = jax.random.split(key, 2)
    perm = jax.random.permutation(grab, len(dataset))

    for batch_idx, batch in enumerate(range(0, len(dataset), batch_size)):
        ids = perm[batch : batch + batch_size]
        batch = dataset.get_batch(ids)
        losses, y, model, opt_state = make_step(model, optimizer, opt_state, batch)
        x = batch[0, 0]
        y = y[0, 0]
        # break
        with train_summary_writer.as_default():
            tf.summary.scalar('loss', losses, step=batch_idx)
        if batch_idx % 16 == 0:
            line_input.set_xdata(np.arange(len(x)))
            line_input.set_ydata(x)
            line_output.set_xdata(np.arange(len(y)))
            line_output.set_ydata(y)

            # Set the limits
            ax.set_xlim(0, len(x))
            ax.set_ylim(-1, 1)

            # Redraw the figure
            clear_output(wait=True)
            print(f"Losses: {losses}")
            print(y)
            plt.draw()
            plt.pause(0.1)
            display(fig)  # Redisplay the figure after updating


Losses: 6.066703796386719
[ 0.00205232  0.00102228 -0.00179449 ...  0.00335476  0.00053784
  0.00472914]


<Figure size 640x480 with 0 Axes>

In [None]:
import jax
from encodec import EncodecModel

key = jax.random.PRNGKey(1)
model = EncodecModel(activation=jax.nn.elu,key=key)
print(model)
# x = jax.random.normal(key, shape=(1, 1, 24000))
x = jax.numpy.ones((1, 1, 24000))
print(x.shape)
y = jax.vmap(model)(x)

EncodecModel(
  encoder=Encoder(
    B_layers=[
      EncoderLayer(
        resblocks=[
          ResBlock(
            conv1=Conv1d(
              num_spatial_dims=1,
              weight=f32[32,32,3],
              bias=f32[32,1],
              in_channels=32,
              out_channels=32,
              kernel_size=(3,),
              stride=(1,),
              padding='SAME',
              dilation=(1,),
              groups=1,
              use_bias=True,
              padding_mode='ZEROS'
            ),
            conv2=Conv1d(
              num_spatial_dims=1,
              weight=f32[32,32,3],
              bias=f32[32,1],
              in_channels=32,
              out_channels=32,
              kernel_size=(3,),
              stride=(1,),
              padding='SAME',
              dilation=(1,),
              groups=1,
              use_bias=True,
              padding_mode='ZEROS'
            ),
            activation=<wrapped function relu>,
            norm=<class 'equin