# Assignment 3: MuseGAN

In [None]:
import random

import musegan
import numpy as np
import torch as t
from torch.utils.data import DataLoader
from torchinfo import summary

N_BARS = 4
N_TRACKS = 5
N_STEPS_PER_BAR = 48
BATCH_SIZE = 64
Z_DIM = 32
# Fall back to the CPU when no CUDA device is present so the notebook still
# runs (slowly) instead of failing at the first `.to(DEVICE)` call.
DEVICE = "cuda:0" if t.cuda.is_available() else "cpu"

# One seed shared by every RNG the notebook touches (torch, random, NumPy).
SEED = 0x0D000721
t.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

## 3.1. Model structure

### 3.1.1. Temporal network

In [None]:
# Temporal network with a 1024-wide hidden layer, sized for N_BARS bars.
temporal_hid_channels = 1024
temp_net = musegan.temporal.TemporalNetwork(
    z_dimension=Z_DIM,
    n_bars=N_BARS,
    hid_channels=temporal_hid_channels,
)
# Layer-by-layer summary for a batch of BATCH_SIZE latent vectors of size Z_DIM.
summary(temp_net, input_size=(BATCH_SIZE, Z_DIM))

### 3.1.2. Bar generator

In [None]:
# Architecture sizes shared by the bar generator and the full generator.
HID_FEATURES = 1152
HID_CHANNELS = 192
N_PITCHES = 84

bar_gen = musegan.bar_generator.BarGenerator(
    n_pitches=N_PITCHES,
    n_steps_per_bar=N_STEPS_PER_BAR,
    hid_channels=HID_CHANNELS,
    hid_features=HID_FEATURES,
    z_dimension=Z_DIM,
)
# The bar generator's input is four Z_DIM-sized vectors laid side by side
# (presumably one per latent stream, matching MuseGenerator's four latent
# inputs — confirm against BarGenerator).
bar_input_width = 4 * Z_DIM
summary(bar_gen, input_size=(BATCH_SIZE, bar_input_width))

### 3.1.3. Generator

In [None]:
# MuseGenerator takes four latent inputs: two of shape (batch, Z_DIM) and two
# of shape (batch, N_TRACKS, Z_DIM).
generator_input_sizes = (
    (BATCH_SIZE, Z_DIM),
    (BATCH_SIZE, Z_DIM),
    (BATCH_SIZE, N_TRACKS, Z_DIM),
    (BATCH_SIZE, N_TRACKS, Z_DIM),
)
# NOTE(review): hid_channels here is twice the bar generator's HID_CHANNELS;
# confirm this is intentional and matches how MuseGenerator uses it.
muse_gen = musegan.generator.MuseGenerator(
    z_dimension=Z_DIM,
    hid_channels=HID_CHANNELS * 2,
    hid_features=HID_FEATURES,
    n_tracks=N_TRACKS,
    n_bars=N_BARS,
    n_steps_per_bar=N_STEPS_PER_BAR,
    n_pitches=N_PITCHES,
)
summary(muse_gen, input_size=generator_input_sizes)

### 3.1.4. Discriminator (critic)

In [None]:
# The critic is summarized on inputs laid out as
# (batch, tracks, bars, steps per bar, pitches) — presumably multi-track
# piano rolls, matching the generator's output.
critic_input_size = (BATCH_SIZE, N_TRACKS, N_BARS, N_STEPS_PER_BAR, N_PITCHES)
critic = musegan.critic.MuseCritic(
    n_pitches=N_PITCHES,
    n_steps_per_bar=N_STEPS_PER_BAR,
    n_bars=N_BARS,
    n_tracks=N_TRACKS,
    hid_channels=128,
)
summary(critic, input_size=critic_input_size)

## 3.2. Dataset

In [None]:
# Prepared training archive loaded by musegan.dataset.LPDDataset
# (presumably the LPD-5 dataset, per the file name — confirm).
DATASET_PATH = "prepared/train_x_lpd_5.npz"

In [None]:
def seed_worker(_):
    """Re-seed Python's ``random`` and NumPy inside a DataLoader worker.

    Intended as a ``worker_init_fn``. The worker id argument is ignored: the
    seed is derived from ``t.initial_seed()`` instead, which within a worker
    is the per-worker torch seed (see the PyTorch reproducibility notes).
    """
    derived_seed = t.initial_seed() % 2**32  # NumPy only accepts 32-bit seeds
    random.seed(derived_seed)
    np.random.seed(derived_seed)


# Dedicated RNG for the DataLoader so the shuffle order is reproducible.
g = t.Generator()
g.manual_seed(0x0D000721)

dataset = musegan.dataset.LPDDataset(DATASET_PATH)
# NOTE(review): num_workers is left at DataLoader's default (0, loading in the
# main process), so seed_worker only takes effect if workers are enabled.
loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,  # skip the trailing partial batch so every batch is full
    generator=g,
    worker_init_fn=seed_worker,
)
first_sample = dataset[0]
print(first_sample.shape)
print(len(dataset))

## 3.3. Training

In [None]:
# Checkpoint directory passed to musegan.train.Trainer (presumably where it
# saves model/optimizer state — confirm in Trainer).
CKPT_PATH = "ckpt/"

In [None]:
# Adam hyper-parameters shared by the generator and the critic.
LEARNING_RATE = 0.001
ADAM_BETAS = (0.5, 0.9)

# Move each network to the target device and initialise its weights *before*
# building its optimizer. The optimizer then always tracks the final parameter
# objects, even if initialize_weights ever replaced a parameter rather than
# filling it in place. Weights are still initialized generator-first, then
# critic, so the global RNG is consumed in the same order as before.
muse_gen = muse_gen.to(DEVICE).apply(musegan.utils.initialize_weights)
g_optimizer = t.optim.Adam(muse_gen.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

critic = critic.to(DEVICE).apply(musegan.utils.initialize_weights)
c_optimizer = t.optim.Adam(critic.parameters(), lr=LEARNING_RATE, betas=ADAM_BETAS)

In [None]:
N_EPOCHS = 10

# Generator, critic, their optimizers, checkpoint dir and device, passed
# positionally exactly as before.
trainer = musegan.train.Trainer(
    muse_gen, critic, g_optimizer, c_optimizer, CKPT_PATH, DEVICE
)
# NOTE(review): melody_groove receives the track count (N_TRACKS); confirm
# this matches Trainer.train's expectation.
trainer.train(
    loader,
    epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    melody_groove=N_TRACKS,
)