In [2]:
import torch
import time
from utils.utils import generate_mask, load_model, writeDACFile, sample_top_n
from dataloader.dataset import CustomDACDataset
from utils.utils import interpolate_vectors, breakpoints, breakpoints_classseq

import os
import yaml

from DACTransformer.RopeCondDACTransformer import RopeCondDACTransformer

import numpy as np
import matplotlib.pyplot as plt

import dac
import soundfile as sf
import IPython.display as ipd

%load_ext autoreload
%autoreload 2


## Parameters

In [3]:
## params ##########################################################
# set this to whatever you called the experiment in the top of your params.yaml file.
experiment_name= "mini_test_01" #"smalltest_dataset" 
# probably don't change this is the default, set in the params.yaml file.
checkpoint_dir = 'runs' + '/' + experiment_name  

cptnum =  100 # (the checkpoint number must be in the checkpoint directory)
SAVEWAV=False
DEVICE='cuda' #######''cuda'
gendur=20 #how many seconds you wnat your output sound to be
topn=1 # sample from the top n logits
device = DEVICE
###########################################################################
#  Choose a breakpoint sequence (and/or make one yourself) ...
###########################################################################
morphname='conditioning'  ###   (choose from breakpoint sets defined below)
#morphname='sweep'  ###   (choose from breakpoint sets defined below)

### Read Paramfile and get class list

In [4]:
#any config.yaml files used for training are copied to the checkpoint directory as "params.yaml"
paramfile = checkpoint_dir + '/' +  'params.yaml' 
print(f"will use paramfile= {paramfile}") 
# Load YAML file
with open(paramfile, 'r') as file:
    params = yaml.safe_load(file)

# Create an instance of the dataset
data_dir = params['data_dir']
data_frames =  params['data_frames']
dataset = CustomDACDataset(data_dir=data_dir, metadata_excel=data_frames, transforms=None)
FEATURES = params['FEATURES']
#For your reference:
#Print the list of all classes
classes=dataset.get_class_list()
print(f'classes={classes}')
print(f' ------- One hot vectors for classes ----------')
for i in range(len(classes)):
    print(f' {classes[i]} : \t{dataset.onehot(classes[i])}')

will use paramfile= runs/mini_test_01/params.yaml
classes=['Sax Baritone', 'Sax Soprano', 'Sax Tenor']
 ------- One hot vectors for classes ----------
 Sax Baritone : 	tensor([1., 0., 0.])
 Sax Soprano : 	tensor([0., 1., 0.])
 Sax Tenor : 	tensor([0., 0., 1.])


Morph over a vectors in vsequence lineary for (noramlized) time steps vtimes. Create your sequence:

### <font color='blue'> Derived parameters  </font>

In [5]:
# Get parameters from yaml file and derive any necessary
######################################################

inference_steps=86*gendur  #86 frames per second
    
TransformerClass =  globals().get(params['TransformerClass'])  
print(f"using TransformerClass = {params['TransformerClass']}") 
print(f' and TransformerClass is class object {TransformerClass}')

cond_size = 8 # num_classes + num params - not a FREE parameter!

### embed_size = params['tblock_input_size'] -cond_size # 240 #32  # embed_size must be divisible by num_heads and by num tokens
embed_size = params['model_size'] # 240 #32  # embed_size must be divisible by num_heads and by num tokens
print(f'embed_size is {embed_size}')


fnamebase='out' + '.e' + str(embed_size) + '.l' + str(params['num_layers']) + '.h' + str(params['num_heads']) + '_chkpt_' + str(cptnum).zfill(4) 
checkpoint_path = checkpoint_dir + '/' +  fnamebase  + '.pth' 

# for saving sound 
outdir=checkpoint_dir

print(f'checkpoint_path = {checkpoint_path}, fnamebase = {fnamebase}' )

using TransformerClass = RopeCondDACTransformer
 and TransformerClass is class object <class 'DACTransformer.RopeCondDACTransformer.RopeCondDACTransformer'>
embed_size is 64
checkpoint_path = runs/mini_test_01/out.e64.l6.h8_chkpt_0100.pth, fnamebase = out.e64.l6.h8_chkpt_0100


In [6]:
if DEVICE == 'cuda' :
    torch.cuda.device_count()
    torch.cuda.get_device_properties(0).total_memory/1e9

    device = torch.device(DEVICE) # if the docker was started with --gpus all, then can choose here with cuda:0 (or cpu)
    torch.cuda.device_count()
    print(f'memeory on cuda 0 is  {torch.cuda.get_device_properties(0).total_memory/1e9}')
else :
    device=DEVICE
device

memeory on cuda 0 is  6.21903872


device(type='cuda')

# The inference method

In [7]:
import torch
import random
import threading
import mido

# Shared class index variable, updated by MIDI thread
current_class_idx = 0  # Default to the first class (update via MIDI)

def midi_listener(port_name, num_classes, n_params=6):
    """
    Listens for MIDI input and updates `current_class_idx` dynamically.
    
    Args:
        port_name (str): The name of the MIDI port to listen on.
        num_classes (int): Number of available classes.
    """
    global current_class_idx
    global params
    params = torch.zeros(n_params).to(device)
    params = params + .5

    try:
        with mido.open_input(port_name) as port:
            print(f"Listening for MIDI on {port_name}...")
            for msg in port:
                if msg.type == 'note_on':  # Use note number to pick a class
                    new_class = msg.note % num_classes  # Map MIDI notes to classes
                    current_class_idx = new_class
                    print(f"Updated class index: {current_class_idx}")
                
                elif msg.type == 'control_change':  # Use CC for a different mapping
                    new_param = msg.value
                    param_idx = msg.control%21
                    params[param_idx] = new_param/127
                    print(params)
                    print(f"Updated param {param_idx} via CC: {new_param}")

    except Exception as e:
        print(f"MIDI Error: {e}")

# midi_listener(mido.get_input_names()[1], len(classes))


In [8]:

def generate_midi_controlled_cond(inference_steps, classes, param_count=1):
    """
    Generates a conditioning sequence where the class index is controlled by live MIDI input.
    
    Args:
        inference_steps (int): Number of time steps (frames) for inference.
        classes (list): List of class names (one-hot encoded).
        param_count (int): Number of continuous parameters (random walk).
    
    Returns:
        cond: A Tensor of shape (1, inference_steps, cond_size).
    """
    num_classes = len(classes)
    cond_size = num_classes + param_count

    # Prepare a buffer for (inference_steps, cond_size)
    cond = torch.zeros(inference_steps, cond_size)
    
    for i, p in enumerate(params):
        cond[0, num_classes + i] = p
    global current_class_idx
    cond[0, current_class_idx] = 1.0
    for t in range(1, inference_steps):
        # Update class index from MIDI input (global variable)
        #global current_class_idx

        # Copy previous step
        cond[t] = cond[t-1]

        # Reset the class portion to zero, then set the one-hot class
        cond[t, :num_classes] = 0.0
        cond[t, current_class_idx] = 1.0
        
        # Interpolate cond with previous cond using torch.lerp
        alpha = 0.9
        # prev_cond.to(device)
        # cond.to(device)
        # cond[t] = torch.lerp(prev_cond, cond[t], alpha)
    # Add batch dimension => shape (1, T, cond_size)
    # print("BBBBBBBBBBBBBBBBBBB", current_class_idx, cond)
    # cond[:, 0] = 1.0
    return cond.unsqueeze(0)

In [9]:
#Load the stored model
model, _, Ti_context_length, vocab_size, num_codebooks, cond_size = load_model(checkpoint_path,  TransformerClass, DEVICE)

print(f'Mode loaded, context_length (Ti_context_length) = {Ti_context_length}')
# Count the number of parameters
num_params = sum(p.numel() for p in model.parameters())
print(f'Total number of parameters: {num_params}')

model.to(device);

 ------------- input_size is embed_size + cond_size = 73
 ------------- embed_dim (64) must be divisible by num_heads (8)
Setting up MultiEmbedding with vocab_size= 1024, embed_size= 64, num_codebooks= 4
Setting up RotaryPositionalEmbedding with embed_size= 64, max_len= 575
Mode loaded, context_length (Ti_context_length) = 258
Total number of parameters: 561024


In [10]:
dacmodel_path = dac.utils.download(model_type="44khz") 
print(f'The DAC decoder is in {dacmodel_path}')
with torch.no_grad():
    dacmodel = dac.DAC.load(dacmodel_path)

    dacmodel.to(device); #wanna see the model? remove the semicolon
    dacmodel.eval();  # need to be "in eval mode" in order to set the number of quantizers

The DAC decoder is in /home/angel/.cache/descript/dac/weights_44khz_8kbps_0.0.1.pth


  WeightNorm.apply(module, name, dim)


In [11]:
 
def inference(model, inference_cond, Ti_context_length, vocab_size, num_tokens, inference_steps, topn, fname, prev_tokens) :
    model.eval()
    with torch.no_grad():
        # print(Ti_context_length)
        mask = generate_mask(Ti_context_length, Ti_context_length).to(device)

        # pseudocódigo
        prev_context = prev_tokens[:, -Ti_context_length:, :]
        # y en la próxima llamada:
        input_data = prev_context

        #input_data = torch.randint(0, vocab_size, (1, Ti_context_length, num_tokens)).to(device)  # Smaller context window for inference
        # print("Inference cond: ", inference_cond.shape)
        #Extend the first conditional vector to cover the "input" which is of length Ti_context_length
        inference_cond = torch.cat([inference_cond[:, :1, :].repeat(1, Ti_context_length, 1), inference_cond], dim=1)
        # print("Inference cond: ", inference_cond.shape)
        predictions = []

        
        t0 = time.time()
        for i in range(inference_steps):  # 
            # print(input_data.shape, inference_cond.shape, mask.shape)

            if cond_size == 0:
                output = model(input_data, None, mask) # step through 
            else :
                # print(f'input_data.shape = {input_data.shape}, inference_cond[:, i:Ti_context_length+i, :].shape = {inference_cond[:, i:Ti_context_length+i, :].shape}, mask.shape = {mask.shape}')
                # print(inference_cond[:, i:Ti_context_length+i, :])
                output = model(input_data, inference_cond[:, i:Ti_context_length+i, :], mask) # step through

            # This takes the last vector of the sequence (the new predicted token stack) so has size(b,steps,4,1024)
            # This it takes the max across the last dimension (scores for each element of the vocabulary (for each of the 4 tokens))
            # .max returns a duple of tensors, the first are the max vals (one for each token) and the second are the
            #        indices in the range of the vocabulary size. 
            # THAT IS, the selected "best" tokens (one for each codebook) are taken independently
            ########################### next_token = output[:, -1, :, :].max(-1)[1]  # Greedy decoding for simplicity

            next_token = sample_top_n(output[:, -1, :, :],topn) # topn=1 would be the same as max in the comment line above    
            predictions.append(next_token)
            input_data = torch.cat([input_data, next_token.unsqueeze(1)], dim=1)[:, 1:]  # Slide window

        t1 = time.time()
        inf_time = t1-t0

        dacseq = torch.cat(predictions, dim=0).unsqueeze(0).transpose(1, 2)

        return dacseq



In [12]:
print(Ti_context_length)

258


In [13]:
import torch
import sounddevice as sd
import queue
import threading
import time

# Initialize audio queue
try:
    del audio_queue
except:
    pass
audio_queue = queue.Queue()

# Audio callback function
def audio_callback(outdata, frames, time_info, status):
    if status:
        print(f"Status: {status}")
    try:
        chunk = audio_queue.get_nowait()
        print(chunk)
        print(frames)
        outdata[:len(chunk)] = chunk.reshape(-1, 1)
        if len(chunk) < frames:
            outdata[len(chunk):] = 0
    except queue.Empty:
        outdata.fill(0)

# Print available midi ports
port = mido.get_input_names()[1]
# Start MIDI listener thread
midi_thread = threading.Thread(target=midi_listener, args=(port, len(classes)))  # Replace "YourMIDIport" with actual port
# midi_thread.daemon = True
midi_thread.start()

dur = 3
blocksize = int(44100 * dur)
inference_steps = int(86 * dur)
Ti_context_length = int(86 * dur)
# device="cpu"
# Function to generate audio using the model
def generate_audio():
    with torch.no_grad():
        prev_tokens = torch.randint(0, vocab_size, (1, Ti_context_length, num_codebooks)).to(device)

        while True:
            print(f'current_class = {classes[current_class_idx]}')
            cond = generate_midi_controlled_cond(
                inference_steps, 
                classes, 
                param_count=6).to(device)  
            print(cond[0][0]) 
            # print("AAAAAAAAAAAAAAAAAAA",cond)
            # Model inference
            # print(cond.isnan().sum(), cond.isinf().sum())
            codeseq = inference(model, cond, Ti_context_length, vocab_size, num_codebooks, inference_steps, topn, "", prev_tokens)
            # print(f'codeseq shape = {codeseq.shape}')
            prev_tokens = codeseq.reshape(1, -1, num_codebooks)
            dac_file = dac.DACFile(
                codes=codeseq.cpu(),
                chunk_length=codeseq.shape[2],
                original_length=int(codeseq.shape[2] * 512),
                input_db=torch.tensor(-20),
                channels=1,
                sample_rate=44100,
                padding=True,
                dac_version='1.0.0'
            )
            audio_signal = dacmodel.decompress(dac_file)
            audio_data = audio_signal.samples.view(-1).numpy()
            
            # Enqueue audio data
            # Slice the audio_data into 4096 frames blocks
            audio_queue.put(audio_data[:blocksize])
            # for i in range(0, len(audio_data), blocksize):
            #     chunk = audio_data[i:i+blocksize]
            #     audio_queue.put(chunk)

# Start audio stream
samplerate = 44100
stream = sd.OutputStream(
    samplerate=samplerate,
    channels=1,
    blocksize=blocksize,
    callback=audio_callback
)
stream.start()

# Start audio generation in a separate thread
audio_thread = threading.Thread(target=generate_audio)
# audio_thread.daemon = True
audio_thread.start()


import tkinter as tk
import time
import threading



def create_gui():
    """
    Create a Tkinter window that shows `current_class_idx` and the values
    of `params` in real time.
    """
    root = tk.Tk()
    root.title("Real-Time Visualization")

    # Label to show current_class_idx
    idx_label = tk.Label(root, text=f"Current class: {classes[current_class_idx]}", font=("Arial", 14))
    idx_label.pack(pady=5)

    # A set of labels to display each element of `params`.
    # Alternatively, you could replace labels with Tkinter Scales, or embed a
    # matplotlib chart for a histogram.
    param_labels = []
    for i, param_val in enumerate(params):
        lbl = tk.Label(root, text=f"Param {i + 1}: {param_val:.2f}", font=("Arial", 12))
        lbl.pack()
        param_labels.append(lbl)

    def update_gui():
        """
        Refresh the window with new values from the global variables.
        This method re-schedules itself every 100ms.
        """
        # Update the current_class_idx text
        idx_label.config(text=f"Current class: {classes[current_class_idx]}")

        # Update each param label
        for i, val in enumerate(params):
            param_labels[i].config(text=f"{FEATURES[i]}: {val:.2f}")

        # Schedule the next update in 100ms
        root.after(100, update_gui)

    # Kick off the periodic GUI update
    update_gui()

    # Start the Tkinter event loop. This will block until the window is closed.
    root.mainloop()

create_gui()

# Keep the main thread alive while audio is playing
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Stopping...")
    stream.stop()
    stream.close()
    # audio_thread.join()
    # print("Audio stream stopped")
    # midi_thread.join()
    


Listening for MIDI on Launchkey Mini MK3:Launchkey Mini MK3 Launchkey Mi 20:0...
current_class = Sax Baritone
tensor([0.5000, 1.0000, 0.5000, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 1 via CC: 127
tensor([1.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
       device='cuda:0')
tensor([0.5000, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 2 via CC: 15
tensor([0.5512, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 70


Exception in thread Thread-7 (generate_audio):
Traceback (most recent call last):
  File "/home/angel/anaconda3/envs/dacformer/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/home/angel/anaconda3/envs/dacformer/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 766, in run_closure
    _threading_Thread_run(self)
  File "/home/angel/anaconda3/envs/dacformer/lib/python3.10/threading.py", line 953, in run
    self._target(*self._args, **self._kwargs)
  File "/tmp/ipykernel_133888/17961283.py", line 69, in generate_audio
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.


tensor([0.5433, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 69
tensor([0.5354, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 68
tensor([0.5276, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 67
tensor([0.5197, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 66
tensor([0.5118, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 65
tensor([0.5039, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 64
tensor([0.4961, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 63
tensor([0.4882, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 62
tensor([0.4803, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 61
tensor([0.4724, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 v

tensor([0.7717, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 98
tensor([0.7795, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 99
tensor([0.7874, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 100
tensor([0.7953, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 101
tensor([0.8031, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 102
tensor([0.8110, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 103
tensor([0.8189, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 104
tensor([0.8268, 1.0000, 0.1181, 0.5000, 0.5000, 0.5000], device='cuda:0')
Updated param 0 via CC: 105


In [26]:
Ti_context_length = 86*3
inference_steps = 86*3
prev_tokens = torch.randint(0, vocab_size, (1, Ti_context_length, num_codebooks)).to(device)

print(f'current_class = {classes[current_class_idx]}')
cond = generate_midi_controlled_cond(
    inference_steps, 
    classes, 
    param_count=6).to(device)  
codeseq = inference(model, cond, Ti_context_length, vocab_size, num_codebooks, inference_steps, topn, "", prev_tokens)

dac_file = dac.DACFile(
                codes=codeseq.cpu(),
                chunk_length=codeseq.shape[2],
                original_length=int(codeseq.shape[2] * 512),
                input_db=torch.tensor(-20),
                channels=1,
                sample_rate=44100,
                padding=True,
                dac_version='1.0.0'
            )
audio_signal = dacmodel.decompress(dac_file)
audio_data = audio_signal.samples.view(-1).numpy()

import IPython.display as ipd

ipd.Audio(audio_data, rate=44100)

current_class = Sax Baritone


In [15]:
import mido

def discover_midi_ccs():
    # 1. Print available MIDI input ports
    print("Available input ports:")
    for port_name in mido.get_input_names():
        print(f"  {port_name}")

    # 2. Open one of these ports (replace with your device name, or pick the first found port)
    port_name = mido.get_input_names()[1]  # You can hardcode or prompt the user
    with mido.open_input(port_name) as inport:
        print(f"\nListening on: {port_name}")
        print("Move your knobs to see the messages... (Press Ctrl+C to stop)")

        for msg in inport:
            # Print every incoming message
            print(msg)
discover_midi_ccs()

Available input ports:
  Midi Through:Midi Through Port-0 14:0
  Launchkey Mini MK3:Launchkey Mini MK3 Launchkey Mi 20:0
  Launchkey Mini MK3:Launchkey Mini MK3 Launchkey Mi 20:1

Listening on: Launchkey Mini MK3:Launchkey Mini MK3 Launchkey Mi 20:0
Move your knobs to see the messages... (Press Ctrl+C to stop)
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0
clock time=0

KeyboardInterrupt: 