# 🎶 Music Generation - MuseGAN

In this notebook, we'll walk through the steps required to train your own MuseGAN model to generate music in the style of the Bach chorales.

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.autograd import grad
from musegan_utils import notes_to_midi, draw_score

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## 0. Parameters <a name="parameters"></a>

In [None]:
BATCH_SIZE = 64  # mini-batch size handed to the DataLoader

N_BARS = 2  # number of bars kept from the start of each chorale
N_STEPS_PER_BAR = 16  # time steps per bar
MAX_PITCH = 83  # highest pitch index; also used as the fill value for NaN entries during data preparation
N_PITCHES = MAX_PITCH + 1  # size of the one-hot pitch axis
Z_DIM = 32  # length of each latent noise vector
N_TRACKS = 4  # Bach chorales typically have 4 voices (rebound from the data shape when the dataset is loaded)

CRITIC_STEPS = 5  # critic updates performed per generator update
GP_WEIGHT = 10  # multiplier on the gradient-penalty term of the critic loss
CRITIC_LR = 0.001  # Adam learning rate for the critic
GENERATOR_LR = 0.001  # Adam learning rate for the generator
ADAM_BETA_1 = 0.5  # Adam beta1, shared by both optimizers
ADAM_BETA_2 = 0.9  # Adam beta2, shared by both optimizers
EPOCHS = 6000  # passes over the dataloader in the training loop
LOAD_MODEL = False  # NOTE(review): not read anywhere in the visible notebook code

## 1. Prepare the Data

In [None]:
# Load the "train" split of the Bach chorales archive.
file = "/app/data/bach-chorales/Jsb16thSeparated.npz"
# NOTE(review): allow_pickle=True unpickles arbitrary objects from the archive;
# only load files from a trusted source.
with np.load(file, allow_pickle=True) as archive:
    data = archive["train"]

N_SONGS = len(data)
print(f"{N_SONGS} chorales in the dataset")

# Inspect the first chorale; its shape also rebinds N_TRACKS from the data.
chorale = data[0]
N_BEATS, N_TRACKS = chorale.shape
print(f"{N_BEATS, N_TRACKS} shape of chorale 0")
print("\nChorale 0", chorale[:8], sep="\n")

In [None]:
# %%
# Keep the first N_BARS bars of every chorale and replace NaN entries with MAX_PITCH.
steps_kept = N_STEPS_PER_BAR * N_BARS
two_bars = np.array([song[:steps_kept] for song in data])
two_bars = np.nan_to_num(two_bars, nan=MAX_PITCH).astype(int)
two_bars = two_bars.reshape([N_SONGS, N_BARS, N_STEPS_PER_BAR, N_TRACKS])

# One-hot encode the pitch index, then map the zeros to -1 so values are -1/1.
one_hot = np.eye(N_PITCHES)[two_bars]
data_binary = np.where(one_hot == 0, -1.0, one_hot)
data_binary = data_binary.transpose(0, 1, 2, 4, 3)  # [N, bars, steps, pitches, tracks]

songs_tensor = torch.tensor(data_binary, dtype=torch.float32)
dataset = TensorDataset(songs_tensor)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

## 2. Build the GAN <a name="build"></a>

In [None]:
def conv3d_block(in_channels, out_channels, kernel_size, stride, padding):
    """Return ``nn.Sequential(Conv3d, LeakyReLU(0.2))``.

    All arguments are forwarded unchanged to ``nn.Conv3d``.
    """
    conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
    activation = nn.LeakyReLU(negative_slope=0.2)
    return nn.Sequential(conv, activation)

def convt2d_block(in_channels, out_channels, kernel_size, stride, padding, activation="relu", bn=True):
    """Build a transposed-convolution block.

    The block is ``ConvTranspose2d`` followed by an optional ``BatchNorm2d``
    and an optional activation, wrapped in ``nn.Sequential``.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded
            to ``nn.ConvTranspose2d``.
        activation: ``"relu"``, ``"tanh"`` or ``"leakyrelu"`` (slope 0.2).
            Any other value adds no activation layer.
        bn: when true, insert ``BatchNorm2d`` after the convolution.

    Returns:
        nn.Sequential containing the layers in the order described above.
    """
    layers = [nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding)]
    if bn:
        # PyTorch's ``momentum`` is the weight given to the *new* batch
        # statistic, so a Keras-style momentum of 0.9 corresponds to 0.1 here.
        # Passing 0.9 made the running statistics track only the latest batch.
        layers.append(nn.BatchNorm2d(out_channels, momentum=0.1))
    if activation == "relu":
        layers.append(nn.ReLU())
    elif activation == "tanh":
        layers.append(nn.Tanh())
    elif activation == "leakyrelu":
        layers.append(nn.LeakyReLU(0.2))
    return nn.Sequential(*layers)

In [None]:
class TemporalNetwork(nn.Module):
    """Expand one latent vector into one latent vector per bar.

    Input:  ``z`` with ``Z_DIM`` values per sample (``[B, Z_DIM]``).
    Output: ``[B, N_BARS, Z_DIM]``.
    """

    def __init__(self):
        super().__init__()
        # Height grows 1 -> 2 -> N_BARS through the two transposed convolutions.
        self.net = nn.Sequential(
            nn.ConvTranspose2d(Z_DIM, 1024, kernel_size=(2, 1), stride=(1, 1), padding=0),
            # PyTorch momentum weights the new batch statistic (0.1 == Keras 0.9).
            nn.BatchNorm2d(1024, momentum=0.1),
            nn.ReLU(),
            nn.ConvTranspose2d(1024, Z_DIM, kernel_size=(N_BARS - 1, 1), stride=(1, 1), padding=0),
            nn.BatchNorm2d(Z_DIM, momentum=0.1),
            nn.ReLU(),
        )

    def forward(self, z):
        """Return per-bar latent vectors of shape ``[B, N_BARS, Z_DIM]``."""
        x = z.reshape(z.size(0), Z_DIM, 1, 1)
        x = self.net(x)  # [B, Z_DIM, N_BARS, 1] (channels first)
        # Move the bar axis ahead of the channel axis. A plain ``view`` to
        # [B, N_BARS, Z_DIM] would reinterpret memory and mix features of
        # different bars into each output vector.
        return x.squeeze(-1).transpose(1, 2)

In [None]:
class BarGenerator(nn.Module):
    """Generate one bar of one track from a concatenated latent vector.

    Input:  ``[B, Z_DIM * 4]``.
    Output: ``[B, 1, N_STEPS_PER_BAR, N_PITCHES, 1]`` with values in [-1, 1]
    (final activation is tanh).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(Z_DIM * 4, 1024),
            # PyTorch momentum weights the new batch statistic (0.1 == Keras 0.9).
            nn.BatchNorm1d(1024, momentum=0.1),
            nn.ReLU(),
        )
        # Starting from [B, 512, 2, 1]: the three (2,1) blocks double the step
        # axis 2 -> 4 -> 8 -> 16, then the (1,7) and (1,12) blocks widen the
        # pitch axis 1 -> 7 -> 84.
        self.conv_blocks = nn.Sequential(
            convt2d_block(512, 512, (2, 1), (2, 1), (0, 0)),
            convt2d_block(512, 256, (2, 1), (2, 1), (0, 0)),
            convt2d_block(256, 256, (2, 1), (2, 1), (0, 0)),
            convt2d_block(256, 256, (1, 7), (1, 7), (0, 0)),
            convt2d_block(256, 1, (1, 12), (1, 12), (0, 0), activation="tanh", bn=False),
        )

    def forward(self, z):
        """Map ``z`` to a single-bar, single-track piano roll."""
        x = self.fc(z)
        x = x.view(x.size(0), 512, 2, 1)
        x = self.conv_blocks(x)  # [B, 1, 16, 84]
        # The reshape only succeeds when N_STEPS_PER_BAR * N_PITCHES == 16 * 84.
        return x.view(x.size(0), 1, N_STEPS_PER_BAR, N_PITCHES, 1)

In [None]:
class Generator(nn.Module):
    """MuseGAN generator: latent inputs -> multi-track, multi-bar piano roll.

    Inputs (``B`` = batch size):
        chords: ``[B, Z_DIM]``, shared by all tracks, expanded over bars.
        style:  ``[B, Z_DIM]``, shared by all tracks and all bars.
        melody: ``[B, N_TRACKS, Z_DIM]``, per track, expanded over bars.
        groove: ``[B, N_TRACKS, Z_DIM]``, per track, shared by all bars.

    Output: ``[B, N_BARS, N_STEPS_PER_BAR, N_PITCHES, N_TRACKS]``.
    """

    def __init__(self):
        super().__init__()
        # Temporal network for the chords input (attribute name kept as-is).
        self.temporal = TemporalNetwork()
        # MuseGAN uses an independent temporal network per melody track.
        # Sharing one network between chords and every track ties their
        # weights and batch-norm statistics together.
        self.melody_temporal = nn.ModuleList([TemporalNetwork() for _ in range(N_TRACKS)])
        self.bar_generators = nn.ModuleList([BarGenerator() for _ in range(N_TRACKS)])

    def forward(self, chords, style, melody, groove):
        """Generate a batch of scores from the four latent inputs."""
        chords_over_time = self.temporal(chords)  # [B, N_BARS, Z_DIM]
        melody_over_time = [
            temporal(melody[:, t, :]) for t, temporal in enumerate(self.melody_temporal)
        ]

        bars_output = []
        for b in range(N_BARS):
            c = chords_over_time[:, b, :]
            track_output = []
            for t in range(N_TRACKS):
                m = melody_over_time[t][:, b, :]
                g = groove[:, t, :]
                z_input = torch.cat([c, style, m, g], dim=1)  # [B, Z_DIM * 4]
                track_output.append(self.bar_generators[t](z_input))
            # Concatenate the tracks along the last axis.
            bars_output.append(torch.cat(track_output, dim=-1))
        # Concatenate the bars along axis 1.
        return torch.cat(bars_output, dim=1)

In [None]:
class Critic(nn.Module):
    """WGAN critic scoring a batch of piano rolls.

    Input:  ``[B, N_BARS, N_STEPS_PER_BAR, N_PITCHES, N_TRACKS]``.
    Output: ``[B, 1]`` unbounded score.

    The layer sizes assume 16 steps per bar and 84 pitches, so that the
    flattened feature vector has exactly 512 values.
    """

    def __init__(self):
        super().__init__()
        # Collapse the bar axis: N_BARS -> N_BARS - 1 -> 1.
        self.conv1 = conv3d_block(N_TRACKS, 128, (2, 1, 1), (1, 1, 1), 0)
        self.conv2 = conv3d_block(128, 128, (N_BARS - 1, 1, 1), (1, 1, 1), 0)
        # Collapse the pitch axis: 84 -> 7 -> 1.
        self.conv3 = conv3d_block(128, 128, (1, 1, 12), (1, 1, 12), 0)
        self.conv4 = conv3d_block(128, 128, (1, 1, 7), (1, 1, 7), 0)
        # Collapse the step axis: 16 -> 8 -> 4 -> 2 -> 1.
        self.conv5 = conv3d_block(128, 128, (1, 2, 1), (1, 2, 1), 0)
        self.conv6 = conv3d_block(128, 128, (1, 2, 1), (1, 2, 1), 0)
        # These two layers need one step of padding on the step axis. With
        # padding 0, conv7 reduces 4 steps straight to 1 and conv8's 3-wide
        # kernel no longer fits, so the forward pass raises.
        self.conv7 = conv3d_block(128, 256, (1, 4, 1), (1, 2, 1), (0, 1, 0))
        self.conv8 = conv3d_block(256, 512, (1, 3, 1), (1, 2, 1), (0, 1, 0))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(512, 1024)
        self.leaky = nn.LeakyReLU(0.2)
        self.fc2 = nn.Linear(1024, 1)

    def forward(self, x):
        """Return the critic score for each sample in ``x``."""
        # Tracks become the channel axis: [B, tracks, bars, steps, pitches].
        x = x.permute(0, 4, 1, 2, 3)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.conv7(x)
        x = self.conv8(x)  # [B, 512, 1, 1, 1]
        x = self.flatten(x)
        x = self.leaky(self.fc1(x))
        return self.fc2(x)

In [None]:
class MuseGAN:
    """Training wrapper: Wasserstein loss with gradient penalty.

    Holds the generator, the critic and one Adam optimizer for each, and
    exposes ``train_step`` (one critic/generator update round) and
    ``generate`` (sampling).
    """

    def __init__(self, generator, critic, latent_dim, critic_steps, gp_weight):
        """Move both networks to ``device`` and create their optimizers.

        Args:
            generator: module called as ``generator(chords, style, melody, groove)``.
            critic: module mapping a batch of scores to one value per sample.
            latent_dim: stored on the instance; noise is sampled with the
                module-level ``Z_DIM`` that the generator layers are built from.
            critic_steps: critic updates per call to ``train_step``.
            gp_weight: multiplier on the gradient-penalty term.
        """
        self.generator = generator.to(device)
        self.critic = critic.to(device)
        self.latent_dim = latent_dim
        self.critic_steps = critic_steps
        self.gp_weight = gp_weight
        self.g_optimizer = optim.Adam(
            self.generator.parameters(), lr=GENERATOR_LR, betas=(ADAM_BETA_1, ADAM_BETA_2)
        )
        self.c_optimizer = optim.Adam(
            self.critic.parameters(), lr=CRITIC_LR, betas=(ADAM_BETA_1, ADAM_BETA_2)
        )

    def _sample_noise(self, batch_size):
        """Draw the (chords, style, melody, groove) latent inputs on ``device``."""
        chords = torch.randn(batch_size, Z_DIM, device=device)
        style = torch.randn(batch_size, Z_DIM, device=device)
        melody = torch.randn(batch_size, N_TRACKS, Z_DIM, device=device)
        groove = torch.randn(batch_size, N_TRACKS, Z_DIM, device=device)
        return chords, style, melody, groove

    def gradient_penalty(self, real, fake):
        """Return the mean squared deviation of the critic's gradient norm from 1.

        The gradient is taken at random per-sample interpolations between
        ``real`` and ``fake``.
        """
        # The penalty only trains the critic; detaching keeps the generator
        # out of the double-backward graph.
        fake = fake.detach()
        alpha = torch.rand(real.size(0), 1, 1, 1, 1, device=real.device)
        interpolated = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
        pred = self.critic(interpolated)
        grads = grad(pred.sum(), interpolated, create_graph=True)[0]
        # reshape (not view): the gradient may come back non-contiguous.
        grads = grads.reshape(grads.size(0), -1)
        return ((grads.norm(2, dim=1) - 1) ** 2).mean()

    def train_step(self, real):
        """Run ``critic_steps`` critic updates, then one generator update.

        Args:
            real: batch of real scores; moved to ``device`` here.

        Returns:
            ``(c_loss, g_loss)`` as Python floats: the last critic loss and
            the generator loss. The critic loss is NaN if ``critic_steps``
            is 0.
        """
        batch_size = real.size(0)
        real = real.to(device)
        # Make sure batch norm uses batch statistics even if the generator
        # was switched to eval mode for sampling.
        self.generator.train()

        # Train Critic
        c_loss_value = float("nan")
        for _ in range(self.critic_steps):
            # No generator gradients are needed for a critic update.
            with torch.no_grad():
                fake = self.generator(*self._sample_noise(batch_size))
            real_pred = self.critic(real)
            fake_pred = self.critic(fake)
            penalty = self.gradient_penalty(real, fake)
            c_loss = fake_pred.mean() - real_pred.mean() + self.gp_weight * penalty
            self.c_optimizer.zero_grad()
            c_loss.backward()
            self.c_optimizer.step()
            c_loss_value = c_loss.item()

        # Train Generator
        fake = self.generator(*self._sample_noise(batch_size))
        g_loss = -self.critic(fake).mean()
        self.g_optimizer.zero_grad()
        g_loss.backward()
        self.g_optimizer.step()
        return c_loss_value, g_loss.item()

    def generate(self, num_scores):
        """Sample ``num_scores`` scores and return them as a NumPy array."""
        # BatchNorm1d in training mode rejects a single-sample batch, so
        # sample in eval mode and restore the previous mode afterwards.
        was_training = self.generator.training
        self.generator.eval()
        try:
            with torch.no_grad():
                generated = self.generator(*self._sample_noise(num_scores))
        finally:
            self.generator.train(was_training)
        return generated.cpu().numpy()

In [None]:
# Build the two networks and hand them to the training wrapper.
generator = Generator()
critic = Critic()
musegan = MuseGAN(
    generator=generator,
    critic=critic,
    latent_dim=Z_DIM,
    critic_steps=CRITIC_STEPS,
    gp_weight=GP_WEIGHT,
)

## 3. Train the MuseGAN <a name="train"></a>

In [None]:
# Export a sample every `sample_every` epochs (1 = after every epoch).
sample_every = 1

for epoch in range(EPOCHS):
    c_losses, g_losses = [], []
    # Each batch from the TensorDataset is a one-element list holding the tensor.
    for (real_batch,) in dataloader:
        c_loss, g_loss = musegan.train_step(real_batch)
        c_losses.append(c_loss)
        g_losses.append(g_loss)

    epoch_number = epoch + 1
    mean_c_loss = np.mean(c_losses)
    mean_g_loss = np.mean(g_losses)
    print(f"Epoch {epoch_number}/{EPOCHS} - C Loss: {mean_c_loss:.4f}, G Loss: {mean_g_loss:.4f}")

    if epoch_number % sample_every == 0:
        gen_music = musegan.generate(1)
        # The file name uses the zero-based epoch index.
        notes_to_midi(gen_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename=f"output_{epoch:04d}")
        draw_score(gen_music, 0)

In [None]:
# %%
# Batch size used when sampling noise in the experiments that follow.
num_scores = 1


## Changing Chord Noise

In [None]:
# Chord
# Sample the latent inputs on the generator's own device; CPU noise fails
# against a generator that lives on the GPU.
gen_device = next(generator.parameters()).device
chords_noise = torch.randn(num_scores, Z_DIM, device=gen_device)
style_noise = torch.randn(num_scores, Z_DIM, device=gen_device)
melody_noise = torch.randn(num_scores, N_TRACKS, Z_DIM, device=gen_device)
groove_noise = torch.randn(num_scores, N_TRACKS, Z_DIM, device=gen_device)

# Eval mode: BatchNorm1d in training mode rejects a batch of one sample.
# Call generator.train() again before resuming training.
generator.eval()
with torch.no_grad():
    generated_music = generator(chords_noise, style_noise, melody_noise, groove_noise).cpu().numpy()
draw_score(generated_music, 0)
# NOTE(review): no chord vector is altered in this cell; it produces the
# reference score from freshly sampled noise despite the file name.
notes_to_midi(generated_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename="output_midi_chords_changed")

## Changing Style Noise

In [None]:
# Style
# Resample only the style vector; every other input is reused unchanged.
gen_device = next(generator.parameters()).device
style_noise2 = torch.randn(num_scores, Z_DIM, device=gen_device)

# Eval mode: BatchNorm1d in training mode rejects a batch of one sample.
# Call generator.train() again before resuming training.
generator.eval()
with torch.no_grad():
    generated_music = generator(
        chords_noise.to(gen_device),
        style_noise2,
        melody_noise.to(gen_device),
        groove_noise.to(gen_device),
    ).cpu().numpy()
draw_score(generated_music, 0)
notes_to_midi(generated_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename="output_midi_style_changed")

## Changing Melody Noise

In [None]:
# Melody
# Resample the melody vector of the first track only.
gen_device = next(generator.parameters()).device
melody_noise2 = melody_noise.clone().to(gen_device)
melody_noise2[:, 0, :] = torch.randn(num_scores, Z_DIM, device=gen_device)

# Eval mode: BatchNorm1d in training mode rejects a batch of one sample.
# Call generator.train() again before resuming training.
generator.eval()
with torch.no_grad():
    generated_music = generator(
        chords_noise.to(gen_device),
        style_noise.to(gen_device),
        melody_noise2,
        groove_noise.to(gen_device),
    ).cpu().numpy()
draw_score(generated_music, 0)
notes_to_midi(generated_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename="output_midi_melody_changed")

## Changing Groove Noise

In [None]:
# Groove
# Resample the groove vector of the last track only.
gen_device = next(generator.parameters()).device
groove_noise2 = groove_noise.clone().to(gen_device)
groove_noise2[:, -1, :] = torch.randn(num_scores, Z_DIM, device=gen_device)

# Eval mode: BatchNorm1d in training mode rejects a batch of one sample.
# Call generator.train() again before resuming training.
generator.eval()
with torch.no_grad():
    generated_music = generator(
        chords_noise.to(gen_device),
        style_noise.to(gen_device),
        melody_noise.to(gen_device),
        groove_noise2,
    ).cpu().numpy()
draw_score(generated_music, 0)
notes_to_midi(generated_music, N_BARS, N_TRACKS, N_STEPS_PER_BAR, filename="output_midi_groove_changed")