# Data Exploration and Setup
This notebook demonstrates how to use the project's source code for data loading and exploration.
We are using a hybrid approach: core logic is in `src/` and exploration is in `notebooks/`.

In [None]:
import torch
import os
import sys

# Add src to path
sys.path.append(os.path.abspath('..'))

from ...models.musicbert import MusicBERT, MusicBERTConfig
from ...data.musicbert import MusicBERTDataset

# Configuration (Must match training config)
MAX_SEQ_LEN = 1024
# Optimized vocab sizes for OctupleMIDI (TimeSig, Tempo, Bar, Pos, Instr, Pitch, Dur, Vel)
# +4 for special tokens (PAD, MASK, CLS, EOS)
VOCAB_SIZES = [258, 53, 260, 132, 133, 132, 132, 36]

config = MusicBERTConfig(
    vocab_sizes=VOCAB_SIZES,
    element_embedding_size=512,
    hidden_size=512,
    num_layers=4,
    num_attention_heads=8,
    ffn_inner_hidden_size=2048,
    dropout=0.1,
    max_position_embeddings=MAX_SEQ_LEN,
    max_seq_len=MAX_SEQ_LEN
)

# Initialize Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MusicBERT(config).to(device)

# Load Checkpoint
CHECKPOINT_PATH = '../checkpoints/musicbert_latest.pth'

if os.path.exists(CHECKPOINT_PATH):
    print(f"Loading checkpoint from {CHECKPOINT_PATH}")
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    print("Model loaded and set to evaluation mode.")
else:
    print(f"Checkpoint not found at {CHECKPOINT_PATH}")

Loading checkpoint from ../checkpoints/musicbert_latest.pth
Model loaded and set to evaluation mode.


In [2]:
# Test with dummy input
batch_size = 1
seq_len = 128 # Arbitrary sequence length for testing

# Generate random input (batch_size, seq_len, 8)
# We need to respect vocab sizes for each attribute
dummy_input = torch.zeros((batch_size, seq_len, 8), dtype=torch.long).to(device)
for i in range(8):
    dummy_input[:, :, i] = torch.randint(0, VOCAB_SIZES[i], (batch_size, seq_len)).to(device)

# Generate dummy attention mask (batch_size, seq_len)
# False indicates valid tokens, True indicates padding (ignored)
dummy_mask = torch.zeros((batch_size, seq_len), dtype=torch.bool).to(device)

print(f"Testing model with input shape: {dummy_input.shape}")

try:
    with torch.no_grad():
        logits_list = model(dummy_input, attention_mask=dummy_mask)
    
    print("Forward pass successful!")
    print(f"Number of output heads: {len(logits_list)}")
    print(f"Output shape per head: {logits_list[0].shape}") # Should be (batch_size, seq_len, vocab_size)
    
except Exception as e:
    print(f"Error during forward pass: {e}")

Testing model with input shape: torch.Size([1, 128, 8])


  output = torch._nested_tensor_from_mask(


Forward pass successful!
Number of output heads: 8
Output shape per head: torch.Size([1, 128, 258])


In [3]:
import numpy as np

# Use the Dataset class to handle loading and masking
DATA_PATH = '../data/processed/'

if os.path.exists(DATA_PATH):
    # Initialize Dataset
    dataset = MusicBERTDataset(DATA_PATH, max_seq_len=MAX_SEQ_LEN, vocab_sizes=VOCAB_SIZES)
    print(f"Dataset size: {len(dataset)}")
    
    if len(dataset) > 0:
        # Get a sample (e.g., index 0)
        sample_idx = 0
        sample = dataset[sample_idx]
        
        input_ids = sample['input_ids'].unsqueeze(0).to(device) # (1, seq_len, 8)
        labels = sample['labels'].unsqueeze(0).to(device)       # (1, seq_len, 8)
        attention_mask = sample['attention_mask'].unsqueeze(0).to(device) # (1, seq_len)
        
        print(f"Input shape: {input_ids.shape}")
        
        # Run Model
        model.eval()
        with torch.no_grad():
            logits_list = model(input_ids, attention_mask=attention_mask)
            
        # Get Predictions
        # logits_list is list of 8 tensors, each (batch, seq_len, vocab_size)
        predictions = []
        for i in range(8):
            # Get max over vocab dimension
            pred_tokens = torch.argmax(logits_list[i], dim=-1) # (batch, seq_len)
            predictions.append(pred_tokens)
            
        # Stack predictions to get (batch, seq_len, 8)
        predicted_tokens = torch.stack(predictions, dim=2).squeeze(0).cpu().numpy()
        
        # Convert inputs and labels to numpy for display
        input_ids_np = input_ids.squeeze(0).cpu().numpy()
        labels_np = labels.squeeze(0).cpu().numpy()
        
        # Print Results
        print("\n--- Results (First 20 tokens) ---")
        print(f"{'Index':<6} | {'Original':<30} | {'Masked Input':<30} | {'Predicted':<30}")
        print("-" * 105)
        
        for i in range(min(20, len(input_ids_np))):
            orig = str(labels_np[i])
            masked = str(input_ids_np[i])
            pred = str(predicted_tokens[i])
            
            # Highlight if masked (input != label)
            is_masked = not np.array_equal(labels_np[i], input_ids_np[i])
            marker = "*" if is_masked else " "
            
            print(f"{i:<6} | {orig:<30} | {marker} {masked:<28} | {pred:<30}")
    else:
        print("Dataset is empty.")
else:
    print(f"Data path {DATA_PATH} not found.")

Dataset size: 909
Input shape: torch.Size([1, 1024, 8])



--- Results (First 20 tokens) ---
Index  | Original                       | Masked Input                   | Predicted                     
---------------------------------------------------------------------------------------------------------
0      | [2 2 2 2 2 2 2 2]              |   [2 2 2 2 2 2 2 2]            | [2 2 2 2 2 2 2 2]             
1      | [ 5 29  4 70 11 34 11 34]      |   [ 5 29  4 70 11 34 11 34]    | [ 5 29  4 70 11 34 11 34]     
2      | [ 6  5  4 51 23 20 11 34]      | * [ 6  5  4  1 23  1 11 34]    | [ 6  5  4 48 23 27 11 34]     
3      | [ 6  5  4 79 12 34 11 34]      | * [ 6  5  4  1 12  1 11 34]    | [ 6  5  4 67 12 27 11 34]     
4      | [ 6  9  4 58 21 16 11 34]      | * [ 6  9  4  1 21  1 11 34]    | [ 6  9  4 67 21 27 11 34]     
5      | [ 6 13  4 63 17 18 11 34]      | * [ 6 13  4  1 17  1 11 34]    | [ 6 13  4 67 17 29 11 34]     
6      | [ 6 13  4 77  9 33 11 34]      | * [ 6 13  4  1  9  1 11 34]    | [ 6 13  4 60  9 29 11 34]     
7      | [ 