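Create a simple music generation model using a variational autoencoder (VAE) intended for a MIDI file dataset. In this notebook the MIDI data is simulated with randomly generated note sequences, and the final cell synthesizes a fixed tune independently of the model.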

In [None]:
#Step 1
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [None]:
# Step 2: Build a synthetic stand-in for a MIDI dataset.
# Each of the num_samples rows is a sequence of sequence_length note indices
# drawn uniformly from [0, num_notes); dividing by num_notes scales them into
# [0, 1) so they live in the same range as the decoder's sigmoid output.
num_samples = 500
sequence_length = 32
num_notes = 88
music_data = torch.from_numpy(
    np.random.randint(0, num_notes, size=(num_samples, sequence_length))
).float() / num_notes

In [None]:
# Step 3: Define VAE Model for Music
class MusicVAE(nn.Module):
    """Small fully-connected variational autoencoder for fixed-length sequences.

    The encoder maps ``input_dim`` values through one ReLU hidden layer to the
    mean and log-variance of a ``latent_dim``-dimensional Gaussian; the decoder
    maps a latent vector through one ReLU hidden layer back to ``input_dim``
    values squashed into (0, 1) by a sigmoid.

    Args:
        input_dim: Length of each input sequence. When ``None`` (the default),
            the module-level ``sequence_length`` is read at construction time,
            which is exactly what ``MusicVAE()`` did before.
        hidden_dim: Width of the hidden layer in both encoder and decoder.
        latent_dim: Size of the latent space.
    """

    def __init__(self, input_dim=None, hidden_dim=64, latent_dim=16):
        super().__init__()
        if input_dim is None:
            input_dim = sequence_length

        # Exposed so callers (e.g. sampling code) need not hard-code sizes.
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder (layer creation order kept identical to the original so the
        # same RNG state yields the same initial weights).
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)

        # Decoder
        self.fc2 = nn.Linear(latent_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, input_dim)

    # Step 4: Encoder Function
    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for inputs ``x``."""
        h = torch.relu(self.fc1(x))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and
        ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    # Step 5: Decoder Function
    def decode(self, z):
        """Map latent vectors ``z`` to reconstructions in (0, 1)."""
        h = torch.relu(self.fc2(z))
        return torch.sigmoid(self.fc3(h))

    # Step 6: Forward Pass
    def forward(self, x):
        """Encode, sample a latent, and decode.

        Returns:
            ``(reconstructed, mu, logvar)`` -- everything ``vae_loss`` needs.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        reconstructed = self.decode(z)
        return reconstructed, mu, logvar

# Step 7: Loss Function
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """Return the VAE objective: reconstruction MSE plus a weighted KL term.

    Args:
        recon_x: Decoder output, same shape as ``x``.
        x: Target values.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.
        beta: Weight on the KL term. The default 1.0 reproduces the original
            objective; other values give a beta-VAE style trade-off between
            reconstruction accuracy and latent regularization.

    Returns:
        Scalar tensor ``mse(recon_x, x) + beta * kl``. Both terms are means
        over all elements rather than sums.
    """
    reconstruction_loss = nn.functional.mse_loss(recon_x, x)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, 1)), averaged over elements.
    kl_loss = -0.5 * torch.mean(
        1 + logvar - mu.pow(2) - logvar.exp()
    )
    return reconstruction_loss + beta * kl_loss

In [None]:
# Step 8: Instantiate the VAE and an Adam optimizer over all of its parameters.
model = MusicVAE()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Step 9: Training Loop
# Full-batch training: every epoch runs the whole dataset through the model
# in one forward/backward pass.
epochs = 50

for epoch in range(epochs):
    # Drop gradients from the previous step before computing new ones.
    optimizer.zero_grad()

    recon, mu, logvar = model(music_data)
    loss = vae_loss(recon, music_data, mu, logvar)

    loss.backward()
    optimizer.step()

    # Progress report on epochs 0, 10, 20, ...
    if not epoch % 10:
        print(f"Epoch {epoch}, Loss = {loss.item():.4f}")


Epoch 0, Loss = 0.0848
Epoch 10, Loss = 0.0846
Epoch 20, Loss = 0.0845
Epoch 30, Loss = 0.0843
Epoch 40, Loss = 0.0844


In [None]:
# Step 10: Generate New Music Sequence
with torch.no_grad():

    # Sample from the standard-normal latent prior. The latent size is read
    # from the decoder's first layer so it always matches the model instead of
    # being hard-coded.
    z = torch.randn(1, model.fc2.in_features)

    # Decode to normalized note values in (0, 1).
    generated_music = model.decode(z)

# Convert back to note indices. Training data was encoded as note / num_notes,
# so invert by multiplying and *rounding* (plain .int() truncates toward zero,
# biasing every note downward). Clamp so an output close to 1.0 cannot become
# the out-of-range index num_notes.
generated_music = (
    torch.round(generated_music * num_notes).clamp(0, num_notes - 1).int()
)

print("Generated Music Note Sequence:")
print(generated_music.tolist())

Generated Music Note Sequence:
[[43, 43, 42, 46, 44, 43, 41, 45, 42, 41, 44, 38, 39, 42, 46, 42, 43, 41, 41, 41, 46, 46, 45, 44, 46, 38, 46, 41, 47, 43, 41, 46]]


In [None]:
import numpy as np
from IPython.display import Audio

sample_rate = 22050

# -------------------------------
# Melody (Been / Snake style)
# -------------------------------
def melody_note(freq, duration=0.3, vibrato_rate=6.0, vibrato_depth=0.02):
    """Synthesize one melody note: a sine with vibrato and a soft linear attack.

    Args:
        freq: Centre frequency in Hz.
        duration: Note length in seconds.
        vibrato_rate: Vibrato speed in Hz.
        vibrato_depth: Peak fractional frequency deviation (0.02 means +/-2%).

    Returns:
        1-D array of ``int(sample_rate * duration)`` samples, peak amplitude 0.4.
    """
    t = np.linspace(0, duration, int(sample_rate * duration), False)

    # Vibrato must modulate the *instantaneous* frequency
    #     f_inst(t) = freq * (1 + depth * sin(2*pi*rate*t)),
    # so the phase is 2*pi times its integral. Computing
    # sin(2*pi*freq*(1 + vibrato)*t) instead adds a freq*t*d(vibrato)/dt term
    # whose swing grows with time (about +/-22% by the end of a 0.3 s note).
    if vibrato_depth and vibrato_rate:
        phase_offset = (vibrato_depth / (2 * np.pi * vibrato_rate)) * (
            1 - np.cos(2 * np.pi * vibrato_rate * t)
        )
    else:
        phase_offset = 0.0
    wave = np.sin(2 * np.pi * freq * (t + phase_offset))

    # Soft attack envelope: ramp 0 -> 1 over the first 10% of the note.
    attack = int(0.1 * len(t))
    envelope = np.ones(len(t))
    envelope[:attack] = np.linspace(0, 1, attack)

    return 0.4 * wave * envelope


# -------------------------------
# Harmonium Drone (Background)
# -------------------------------
def harmonium_drone(freq, duration):
    """Return a sustained drone: fundamental plus a half-amplitude octave, times 0.2.

    Uses the module-level ``sample_rate``; the result has
    ``int(sample_rate * duration)`` samples and no envelope.
    """
    sample_count = int(sample_rate * duration)
    times = np.linspace(0, duration, sample_count, endpoint=False)
    angular = 2 * np.pi * freq
    fundamental = np.sin(angular * times)
    octave = 0.5 * np.sin(angular * 2 * times)
    return 0.2 * (fundamental + octave)


# -------------------------------
# Drum / Percussion
# -------------------------------
def drum_hit(duration=0.08):
    """Return a percussive burst: white noise under a fast exponential decay.

    Draws ``int(sample_rate * duration)`` samples from NumPy's global RNG via
    ``np.random.randn``, so successive calls differ.
    """
    sample_count = int(sample_rate * duration)
    times = np.linspace(0, duration, sample_count, endpoint=False)
    noise = np.random.randn(sample_count)
    decay = np.exp(-20 * times)
    return 0.3 * noise * decay


# -------------------------------
# Naagin Tune (MIDI Notes)
# -------------------------------
nagin_tune = [
    69, 72, 69, 67,
    69, 72, 69, 67,
    69, 72, 76,
    74, 72,
    69, 67, 65,
    67, 69
]

# -------------------------------
# Build Melody + Drums
# -------------------------------
note_duration = 0.3

# Collect per-note segments and join them once at the end; calling
# np.concatenate inside the loop re-copies everything rendered so far on
# every iteration (quadratic in tune length).
melody_parts = []
drum_parts = []
for note in nagin_tune:
    # MIDI note number -> Hz, with note 69 (A4) = 440 Hz.
    freq = 440 * (2 ** ((note - 69) / 12))
    melody_parts.append(melody_note(freq, note_duration))
    drum_parts.append(drum_hit(duration=note_duration))

melody = np.concatenate(melody_parts)
drums = np.concatenate(drum_parts)

# -------------------------------
# Harmonium Drone Length
# -------------------------------
n_samples = len(melody)
total_duration = n_samples / sample_rate
# harmonium_drone sizes its output as int(sample_rate * duration), which can
# round down to n_samples - 1 and break the sum below. Ask for one extra
# sample's worth (same sample spacing) and trim to exactly n_samples.
drone = harmonium_drone(440, total_duration + 1 / sample_rate)[:n_samples]  # Sa / Base note

# -------------------------------
# Final Mix
# -------------------------------
# Melody and drums use the same per-note sample count, so all three tracks
# have length n_samples here.
audio = melody + drone + drums

Audio(audio, rate=sample_rate)
