In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
import os

# Make sure these imports match your file names
from hierarchical_vae import HierarchicalDrumVAE
from dataset import DrumPatternDataset

# --- Configuration ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_PATH = "results/best_model.pth"
DATA_DIR = "data/drums"
RESULTS_DIR = "results"
# Latent sizes must match the architecture the checkpoint was trained with.
Z_HIGH_DIM = 4
Z_LOW_DIM = 12
SEED = 42

# Seed torch's global RNG so "Restart & Run All" reproduces the same outputs
# (torch.randn prior samples, and any stochastic sampling the model may do).
torch.manual_seed(SEED)

# --- Create Directories ---
os.makedirs(os.path.join(RESULTS_DIR, "generated_patterns"), exist_ok=True)
os.makedirs(os.path.join(RESULTS_DIR, "latent_analysis"), exist_ok=True)

# --- Load Model ---
model = HierarchicalDrumVAE(z_high_dim=Z_HIGH_DIM, z_low_dim=Z_LOW_DIM).to(DEVICE)
# The checkpoint is a plain state_dict (it is fed to load_state_dict), so
# weights_only=True loads it unchanged while refusing arbitrary pickled objects
# that could execute code from a tampered file.
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.eval()
print("✅ Model loaded successfully.")

# --- Load Dataset ---
dataset = DrumPatternDataset(DATA_DIR, split='all')
# Later cells index the first pattern of each style; fail early with a clear message.
assert len(dataset) > 0, f"No drum patterns found in {DATA_DIR!r}"
print(f"✅ Dataset with {len(dataset)} patterns loaded.")

# --- Helper function for plotting and saving ---
def save_pattern(pattern, path, title="", instrument_names=None):
    """Render a binary drum pattern as a piano-roll image and write it to ``path``.

    Args:
        pattern: 2-D array-like indexed as (timestep, instrument). It is transposed
            so timesteps run along the x-axis and instruments along the y-axis.
        path: Output image file path, passed to ``Figure.savefig``.
        title: Optional figure title.
        instrument_names: Y-axis tick labels. Defaults to the notebook-level
            ``dataset.instrument_names``. Labels are applied only when their count
            matches the number of instrument rows; otherwise numeric ticks are shown.
    """
    pattern = np.asarray(pattern)
    if instrument_names is None:
        instrument_names = dataset.instrument_names
    names = list(instrument_names)
    # Derive the row count from the data instead of assuming a fixed kit size.
    n_instruments = pattern.shape[1]

    fig, ax = plt.subplots(figsize=(6, 3))
    try:
        ax.imshow(pattern.T, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
        ax.set_xlabel("Timestep")
        ax.set_ylabel("Instrument")
        ax.set_yticks(np.arange(n_instruments))
        if len(names) == n_instruments:
            ax.set_yticklabels(names, fontsize=8)
        ax.set_title(title)
        fig.tight_layout()
        fig.savefig(path)
    finally:
        # Always release the figure, even if savefig fails, so generating many
        # patterns in a loop does not accumulate open figures.
        plt.close(fig)

In [None]:
N_SAMPLES_PER_STYLE = 10
SAMPLE_SEED = 0  # fixed so every re-run of this cell writes the same samples

print(f"🎵 Generating {N_SAMPLES_PER_STYLE} samples for each style...")

# Get a representative style vector (z_high) for each genre by encoding the first
# pattern labelled with that style. Styles with no patterns are skipped instead of
# crashing with an IndexError.
style_latents = {}
style_labels = np.asarray(dataset.styles)
with torch.no_grad():
    for style_idx, style_name in enumerate(dataset.style_names):
        matches = np.flatnonzero(style_labels == style_idx)
        if matches.size == 0:
            print(f"⚠️ No patterns found for style '{style_name}'; skipping.")
            continue
        pattern, _, _ = dataset[int(matches[0])]

        # encode_hierarchy returns a 6-tuple; its last element is used as z_high.
        _, _, _, _, _, z_high = model.encode_hierarchy(pattern.unsqueeze(0).to(DEVICE))
        style_latents[style_name] = z_high

# Dedicated CPU generator: reproducible prior samples without disturbing the global RNG.
sample_generator = torch.Generator().manual_seed(SAMPLE_SEED)

# Generate and save patterns
for style_name, z_high_style in style_latents.items():
    style_dir = os.path.join(RESULTS_DIR, "generated_patterns", style_name)
    os.makedirs(style_dir, exist_ok=True)

    with torch.no_grad():
        for i in range(N_SAMPLES_PER_STYLE):
            # Keep the style (z_high) fixed, but sample a new rhythm (z_low) from the N(0, I) prior
            z_low_sample = torch.randn(1, model.z_low_dim, generator=sample_generator).to(DEVICE)

            # Decode and binarise: an onset is "on" when its probability exceeds 0.5
            logits = model.decode_hierarchy(z_high_style, z_low_sample)
            generated_pattern = (torch.sigmoid(logits) > 0.5).float().squeeze(0).cpu().numpy()

            # Save the resulting pattern image
            save_path = os.path.join(style_dir, f"sample_{i+1}.png")
            save_pattern(generated_pattern, save_path, title=f"{style_name.capitalize()} Sample {i+1}")

print(f"✅ Saved all generated style samples to '{os.path.join(RESULTS_DIR, 'generated_patterns')}/'")

In [None]:
print("🔄 Generating interpolation sequence (Jazz to Rock)...")

# Endpoints of the path: the first Jazz pattern and the first Rock pattern.
jazz_label = dataset.style_names.index('jazz')
rock_label = dataset.style_names.index('rock')
jazz_idx = np.where(dataset.styles == jazz_label)[0][0]
rock_idx = np.where(dataset.styles == rock_label)[0][0]
jazz_pattern, _, _ = dataset[jazz_idx]
rock_pattern, _, _ = dataset[rock_idx]

# Encode both endpoints; slot 2 of encode_hierarchy's output is used as z_low
# and slot 5 as z_high.
with torch.no_grad():
    p1 = jazz_pattern.unsqueeze(0).to(DEVICE)
    p2 = rock_pattern.unsqueeze(0).to(DEVICE)
    (_, _, z_low1, _, _, z_high1), (_, _, z_low2, _, _, z_high2) = (
        model.encode_hierarchy(p1),
        model.encode_hierarchy(p2),
    )

# Walk a straight line between the two latent codes, decoding every step.
interp_dir = os.path.join(RESULTS_DIR, "generated_patterns", "interpolation_jazz_rock")
os.makedirs(interp_dir, exist_ok=True)
n_steps = 10
alphas = np.linspace(0, 1, n_steps)

for step, alpha in enumerate(alphas):
    # Both levels move together: style (z_high) and rhythm (z_low).
    keep = 1 - alpha
    z_high_interp = keep * z_high1 + alpha * z_high2
    z_low_interp = keep * z_low1 + alpha * z_low2

    with torch.no_grad():
        logits = model.decode_hierarchy(z_high_interp, z_low_interp)
        recon_pattern = (torch.sigmoid(logits) > 0.5).float().squeeze(0).cpu().numpy()

    save_pattern(
        recon_pattern,
        os.path.join(interp_dir, f"interp_step_{step}.png"),
        title=f"Interpolation Step {step} (alpha={alpha:.2f})",
    )

print(f"✅ Saved interpolation sequence to '{interp_dir}'")

In [None]:
print("🎭 Generating style transfer example (Hip-hop rhythm -> Latin style)...")

# Rhythm donor = first Hip-hop pattern; style donor = first Latin pattern.
hiphop_idx, latin_idx = (
    np.where(dataset.styles == dataset.style_names.index(genre))[0][0]
    for genre in ('hiphop', 'latin')
)
hiphop_pattern, _, _ = dataset[hiphop_idx]
latin_pattern, _, _ = dataset[latin_idx]

# Encode each donor, keep only the latent we need from it, then decode the hybrid.
with torch.no_grad():
    p_rhythm = hiphop_pattern.unsqueeze(0).to(DEVICE)
    p_style = latin_pattern.unsqueeze(0).to(DEVICE)
    _, _, z_low_rhythm, _, _, _ = model.encode_hierarchy(p_rhythm)   # z_low from hip-hop
    _, _, _, _, _, z_high_style = model.encode_hierarchy(p_style)    # z_high from latin

    logits = model.decode_hierarchy(z_high_style, z_low_rhythm)
    style_transfer_pattern = (torch.sigmoid(logits) > 0.5).float().squeeze(0).cpu().numpy()

# Write both sources and the result side by side for comparison.
transfer_dir = os.path.join(RESULTS_DIR, "generated_patterns", "style_transfer")
os.makedirs(transfer_dir, exist_ok=True)
comparison_images = [
    (hiphop_pattern.numpy(), "01_source_rhythm_hiphop.png", "Source Rhythm (Hip-hop)"),
    (latin_pattern.numpy(), "02_source_style_latin.png", "Source Style (Latin)"),
    (style_transfer_pattern, "03_style_transfer_result.png", "Result: Hip-hop Rhythm in Latin Style"),
]
for image, filename, caption in comparison_images:
    save_pattern(image, os.path.join(transfer_dir, filename), caption)

print(f"✅ Saved style transfer examples to '{transfer_dir}'")

In [None]:
print("🔬 Performing latent space analysis...")

# 1. t-SNE visualization of z_high (mirrors analyze_latent.py so this deliverable
#    is also regenerated from the notebook).
all_z_high = []
all_styles = []
with torch.no_grad():
    for i in range(len(dataset)):
        pattern, style, _ = dataset[i]
        _, _, _, _, _, z_high = model.encode_hierarchy(pattern.unsqueeze(0).to(DEVICE))
        all_z_high.append(z_high.cpu().numpy())
        # int() normalises the label whether the dataset yields a Python int or a 0-d tensor.
        all_styles.append(int(style))

all_z_high = np.concatenate(all_z_high, axis=0)
all_styles = np.array(all_styles)

# sklearn's TSNE requires perplexity < n_samples; clamp so small datasets still work.
perplexity = min(30.0, max(1.0, len(all_z_high) - 1.0))
tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity, learning_rate='auto', init='pca')
z_high_embedded = tsne.fit_transform(all_z_high)

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(z_high_embedded[:, 0], z_high_embedded[:, 1], c=all_styles, cmap='viridis', alpha=0.7)
# legend_elements(num=None) returns one handle per label value actually present
# (in sorted order), so the labels are built from those same values; using the full
# style_names list would misalign every label after a style with no patterns.
present_styles = np.unique(all_styles)
legend_handles, _ = scatter.legend_elements(num=None)
ax.legend(
    handles=legend_handles,
    labels=[dataset.style_names[int(s)] for s in present_styles],
    title="Styles",
)
ax.set_title("t-SNE Visualization of High-Level Latent Space (z_high)")
ax.set_xlabel("t-SNE dimension 1")
ax.set_ylabel("t-SNE dimension 2")
fig.savefig(os.path.join(RESULTS_DIR, "latent_analysis", "tsne_z1_space.png"))
plt.close(fig)
print("✅ Saved t-SNE plot.")

# 2. Dimension Interpretation / Disentanglement Analysis
# Hold the style (z_high) fixed, zero every z_low dimension, then sweep a single
# z_low dimension and record how the decoded pattern and its onset count change.
# NOTE: relies on `style_latents`, built by the style-generation cell above.
SWEEP_STYLE = 'electronic'
if SWEEP_STYLE not in style_latents:
    raise KeyError(f"No style latent for {SWEEP_STYLE!r}; available: {sorted(style_latents)}")
z_high_sweep = style_latents[SWEEP_STYLE]

dim_to_vary = 5  # any index in [0, model.z_low_dim)
if not 0 <= dim_to_vary < model.z_low_dim:
    raise ValueError(f"dim_to_vary={dim_to_vary} is outside [0, {model.z_low_dim})")
# Own name so this cell does not overwrite `n_steps` from the interpolation cell.
n_sweep_steps = 7
sweep_values = np.linspace(-2.5, 2.5, n_sweep_steps)
z_low_base = torch.zeros(1, model.z_low_dim, device=DEVICE)

# squeeze=False keeps `axes` 2-D even when n_sweep_steps == 1.
fig, axes = plt.subplots(1, n_sweep_steps, figsize=(15, 3), squeeze=False)
with torch.no_grad():
    for ax, val in zip(axes[0], sweep_values):
        z_low_varied = z_low_base.clone()
        z_low_varied[0, dim_to_vary] = float(val)

        logits = model.decode_hierarchy(z_high_sweep, z_low_varied)
        swept_pattern = (torch.sigmoid(logits) > 0.5).float().squeeze(0).cpu().numpy()

        density = int(swept_pattern.sum())  # number of active cells in the binary pattern
        # 'nearest' keeps binary cells crisp, matching save_pattern's rendering.
        ax.imshow(swept_pattern.T, aspect='auto', origin='lower', cmap='gray_r', interpolation='nearest')
        ax.set_title(f"Dim {dim_to_vary} = {val:.1f}\nDensity={density}")
        ax.set_xticks([])
        ax.set_yticks([])

fig.suptitle(f"Dimension Interpretation: Varying z_low[{dim_to_vary}] for {SWEEP_STYLE.capitalize()} Style")
# Leave headroom at the top so the two-line panel titles do not collide with the suptitle.
fig.tight_layout(rect=(0, 0, 1, 0.88))
fig.savefig(os.path.join(RESULTS_DIR, "latent_analysis", f"dimension_interpretation_dim{dim_to_vary}.png"))
plt.close(fig)
print("✅ Saved dimension interpretation plot.")