In [1]:
import os
import sys
# Make the project-local `backend.*` imports below resolvable: take the directory
# three levels above the current working directory and append it to sys.path.
# The membership check keeps re-running this cell from adding duplicate entries.
# NOTE(review): this relies on the kernel's working directory (presumably the folder
# containing this notebook) — confirm if Jupyter is launched from somewhere else.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..', '..', '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from backend.ml_model.helper import load_latest_checkpoint
import os
from backend.ml_model.encoding import EncodingConfig
from backend.ml_model.helper import get_device, load_latest_checkpoint
from backend.ml_model.dataloader import OnTheFlyMidiDataset
import glob
from torch.utils.data import DataLoader
import torch
from bertviz import head_view, model_view

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def midi_to_note_name(midi_pitch: int) -> str:
    """Render a MIDI pitch number as a note name with octave, e.g. 60 -> "C4".

    Args:
        midi_pitch: MIDI pitch number, or ``None``.

    Returns:
        The sharp-based note name followed by its octave number
        (pitch 0 maps to "C-1"), or an empty string when ``midi_pitch`` is ``None``.
    """
    if midi_pitch is None:
        return ""
    pitch_classes = ('C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B')
    # Twelve semitones per octave: the quotient selects the octave, the remainder the pitch class.
    octave_index, pitch_class = divmod(midi_pitch, 12)
    return f"{pitch_classes[pitch_class]}{octave_index - 1}"

def token_to_string(input_ids):
    """Build one human-readable label per token of the first sequence in a batch.

    Args:
        input_ids: Batch of token ids; ``input_ids[0]`` must provide ``.tolist()``
            (e.g. a 2-D tensor). Only the first batch item is used.

    Returns:
        A list of label strings, one per token id, in sequence order.
    """

    def _label_for(token_id):
        # Decode the integer id into its description via the encoding config.
        decoded = EncodingConfig.token_to_info(token_id)
        instrument = decoded['instrument']

        if instrument == 'Special':
            # Tag wrapped in angle brackets, e.g. "<Begin>"
            return f"<{decoded['tag']}>"

        if instrument == 'Microtiming':
            # Explicit "+" only for positive offsets; negative values already carry "-"
            delta = decoded['delta']
            prefix = "+" if delta > 0 else ""
            return f"Micro:{prefix}{delta}"

        if instrument == 'Drums':
            # Drums keep the raw pitch number, e.g. "Drums:36"
            return f"Drums:{decoded['pitch']}"

        if instrument == 'Unknown':
            return f"Unk:{token_id}"

        # Any other instrument is labelled with a note name, e.g. "Piano:C4"
        return f"{instrument}:{midi_to_note_name(decoded['pitch'])}"

    first_sequence = input_ids[0].tolist()
    return [_label_for(token_id) for token_id in first_sequence]

In [3]:
# Locations of the training runs and of the dataset.
# The defaults are the original author's local paths; set the ML_RUNS_DIR and
# ML_DATASET_DIR environment variables to run the notebook on another machine
# without editing this cell.
runs_dir = os.environ.get(
    'ML_RUNS_DIR',
    r'C:\Users\mbrun\Documents\University_Branche\Project\webserver\backend\ml_model\runs',
)
dataset_path = os.environ.get(
    'ML_DATASET_DIR',
    r'C:\Users\mbrun\Documents\University_Branche\Project\webserver\backend\ml_model\lpd_5',
)

# Run directory names (sub-folders of runs_dir) of the two models being compared.
model1 = 'Phi-3-33M-head-dim-32-GQA-ratio-16-8-ebs-64-lr-1e-3-epochs-100'
model2 = 'Phi-3-33M-head-dim-64-GQA-ratio-1-ebs-64-lr-1e-3-epochs-100'

# Prefer the Intel XPU backend as before, but fall back to the CPU when this torch
# build has no XPU support or no XPU device is present, so the notebook still runs.
device = 'xpu' if hasattr(torch, 'xpu') and torch.xpu.is_available() else 'cpu'

In [4]:
# Visualize the attention pattern of the first run (model1) on one short token sequence.
#
# Reference note kept from the original: this is how the model is loaded internally
# (with weights)
# from transformers import Phi3Config, Phi3ForCausalLM
# config = Phi3Config(**checkpoint['config'])
# model = Phi3ForCausalLM(config)
# NOTE(review): model_only=True presumably returns just the model object — confirm in
# backend.ml_model.helper.
model = load_latest_checkpoint(os.path.join(runs_dir, model1), model_only=True)

# Collect every .npz file under lpd_5_cleansed (four directory levels deep) and wrap
# the file list in the project dataset and a DataLoader.
# NOTE(review): meaning of n_modulations / chunk_size / warmup_steps is defined by
# OnTheFlyMidiDataset and is not visible here.
sample_files = glob.glob(os.path.join(dataset_path, 'lpd_5_cleansed/*/*/*/*/*.npz'))
dataset = OnTheFlyMidiDataset(sample_files, n_modulations=0, chunk_size=2048, warmup_steps=128)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=False)

# Inference mode, on the configured device
model.eval()
model.to(device)

# Draw a single batch; shuffle=True and no seed is set in this cell, so the sample
# shown can differ between runs.
data_iter = iter(dataloader)
batch = next(data_iter)

# Assumes the first element of the batch holds the token ids — TODO confirm against
# what OnTheFlyMidiDataset yields.
input_ids = batch[0].to(device)

# Keep only the first sample (the 0:1 slice preserves the batch dimension) and its
# first seq_len tokens, so the attention view stays readable.
seq_len = 32
input_ids = input_ids[0:1, :seq_len]

# Forward pass without gradient tracking, asking the model to return attention weights
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)
    attentions = outputs.attentions  # This is a tuple of tensors (one per layer)

# One display label per token id
tokens = token_to_string(input_ids)

# Render the BertViz head view and model view for this sequence
print(f"Rendering Attention Map of {model1}")
head_view(attentions, tokens)
model_view(attentions, tokens)

Loading checkpoint: C:\Users\mbrun\Documents\University_Branche\Project\webserver\backend\ml_model\runs\Phi-3-33M-head-dim-32-GQA-ratio-16-8-ebs-64-lr-1e-3-epochs-100\checkpoint_epoch_96.ph
Rendering Attention Map of Phi-3-33M-head-dim-32-GQA-ratio-16-8-ebs-64-lr-1e-3-epochs-100


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [5]:
# Visualize the attention pattern of the second run (model2) on one short token sequence.
#
# Reference note kept from the original: this is how the model is loaded internally
# (with weights)
# from transformers import Phi3Config, Phi3ForCausalLM
# config = Phi3Config(**checkpoint['config'])
# model = Phi3ForCausalLM(config)
# NOTE(review): model_only=True presumably returns just the model object — confirm in
# backend.ml_model.helper.
model = load_latest_checkpoint(os.path.join(runs_dir, model2), model_only=True)

# Collect every .npz file under lpd_5_cleansed (four directory levels deep) and wrap
# the file list in the project dataset and a DataLoader.
sample_files = glob.glob(os.path.join(dataset_path, 'lpd_5_cleansed/*/*/*/*/*.npz'))
dataset = OnTheFlyMidiDataset(sample_files, n_modulations=0, chunk_size=2048, warmup_steps=128)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0, pin_memory=False)

# Inference mode, on the configured device
model.eval()
model.to(device)

# Draw a single batch; shuffle=True and no seed is set in this cell, so the sample
# shown can differ between runs (and from the sample used for model1).
data_iter = iter(dataloader)
batch = next(data_iter)

# Assumes the first element of the batch holds the token ids — TODO confirm against
# what OnTheFlyMidiDataset yields.
input_ids = batch[0].to(device)

# Keep only the first sample (the 0:1 slice preserves the batch dimension) and its
# first seq_len tokens, so the attention view stays readable.
seq_len = 32
input_ids = input_ids[0:1, :seq_len]

# Forward pass without gradient tracking, asking the model to return attention weights
with torch.no_grad():
    outputs = model(input_ids, output_attentions=True)
    attentions = outputs.attentions  # This is a tuple of tensors (one per layer)

# One display label per token id
tokens = token_to_string(input_ids)

# Render the BertViz head view and model view for this sequence.
# The banner names model2, the checkpoint actually loaded in this cell (it previously
# printed model1, mislabelling the plots).
print(f"Rendering Attention Map of {model2}")
head_view(attentions, tokens)
model_view(attentions, tokens)

Loading checkpoint: C:\Users\mbrun\Documents\University_Branche\Project\webserver\backend\ml_model\runs\Phi-3-33M-head-dim-64-GQA-ratio-1-ebs-64-lr-1e-3-epochs-100\checkpoint_epoch_93.ph
Rendering Attention Map of Phi-3-33M-head-dim-32-GQA-ratio-16-8-ebs-64-lr-1e-3-epochs-100


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>